diff --git a/api/errors/errors.go b/api/errors/errors.go index 7ec06e15201..a02141ea522 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -162,3 +162,6 @@ var ErrRegisteredNFTTokenIDs = errors.New("getting registered nft token ids erro // ErrInvalidRole signals that an invalid role was provided var ErrInvalidRole = errors.New("invalid role") + +// ErrIsDataTrieMigrated signals that an error occurred while trying to verify the migration status of the data trie +var ErrIsDataTrieMigrated = errors.New("could not verify the migration status of the data trie") diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index 410ef8c73c3..1866c3bf022 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -17,26 +17,27 @@ import ( ) const ( - getAccountPath = "/:address" - getAccountsPath = "/bulk" - getBalancePath = "/:address/balance" - getUsernamePath = "/:address/username" - getCodeHashPath = "/:address/code-hash" - getKeysPath = "/:address/keys" - getKeyPath = "/:address/key/:key" - getESDTTokensPath = "/:address/esdt" - getESDTBalancePath = "/:address/esdt/:tokenIdentifier" - getESDTTokensWithRolePath = "/:address/esdts-with-role/:role" - getESDTsRolesPath = "/:address/esdts/roles" - getRegisteredNFTsPath = "/:address/registered-nfts" - getESDTNFTDataPath = "/:address/nft/:tokenIdentifier/nonce/:nonce" - getGuardianData = "/:address/guardian-data" - urlParamOnFinalBlock = "onFinalBlock" - urlParamOnStartOfEpoch = "onStartOfEpoch" - urlParamBlockNonce = "blockNonce" - urlParamBlockHash = "blockHash" - urlParamBlockRootHash = "blockRootHash" - urlParamHintEpoch = "hintEpoch" + getAccountPath = "/:address" + getAccountsPath = "/bulk" + getBalancePath = "/:address/balance" + getUsernamePath = "/:address/username" + getCodeHashPath = "/:address/code-hash" + getKeysPath = "/:address/keys" + getKeyPath = "/:address/key/:key" + getDataTrieMigrationStatusPath = "/:address/is-data-trie-migrated" + getESDTTokensPath = "/:address/esdt" + getESDTBalancePath = "/:address/esdt/:tokenIdentifier" + getESDTTokensWithRolePath = "/:address/esdts-with-role/:role" + getESDTsRolesPath = "/:address/esdts/roles" + getRegisteredNFTsPath = "/:address/registered-nfts" + getESDTNFTDataPath = "/:address/nft/:tokenIdentifier/nonce/:nonce" + getGuardianData = "/:address/guardian-data" + urlParamOnFinalBlock = "onFinalBlock" + urlParamOnStartOfEpoch = "onStartOfEpoch" + urlParamBlockNonce = "blockNonce" + urlParamBlockHash = "blockHash" + urlParamBlockRootHash = "blockRootHash" + urlParamHintEpoch = "hintEpoch" ) // addressFacadeHandler defines the methods to be implemented by a facade for handling address requests @@ -54,6 +55,7 @@ type addressFacadeHandler interface { GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) + IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool } @@ -164,6 +166,11 @@ func NewAddressGroup(facade addressFacadeHandler) (*addressGroup, error) { Method: http.MethodGet, Handler: ag.getGuardianData, }, + { + Path: getDataTrieMigrationStatusPath, + Method: http.MethodGet, + Handler: ag.isDataTrieMigrated, + }, } ag.endpoints = endpoints @@ -442,6 +449,29 @@ func (ag *addressGroup) getAllESDTData(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"esdts": formattedTokens, "blockInfo": blockInfo}) } +// isDataTrieMigrated returns true if the data trie is migrated for the given address +func (ag *addressGroup) isDataTrieMigrated(c *gin.Context) { + addr := c.Param("address") + if addr == "" { + shared.RespondWithValidationError(c, errors.ErrIsDataTrieMigrated, errors.ErrEmptyAddress) + return + } + + options, err := extractAccountQueryOptions(c) + if err != nil { + shared.RespondWithValidationError(c, errors.ErrIsDataTrieMigrated, err) + return + } + + isMigrated, err := ag.getFacade().IsDataTrieMigrated(addr, options) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrIsDataTrieMigrated, err) + return + } + + shared.RespondWithSuccess(c, gin.H{"isMigrated": isMigrated}) +} + func buildTokenDataApiResponse(tokenIdentifier string, esdtData *esdt.ESDigitalToken) *esdtNFTTokenData { tokenData := &esdtNFTTokenData{ TokenIdentifier: tokenIdentifier, diff --git a/api/groups/addressGroup_test.go b/api/groups/addressGroup_test.go index ac4fd92def2..bb19bb81d2c 100644 --- a/api/groups/addressGroup_test.go +++ b/api/groups/addressGroup_test.go @@ -1150,8 +1150,93 @@ func getAddressRoutesConfig() config.ApiRoutesConfig { {Name: "/:address/nft/:tokenIdentifier/nonce/:nonce", Open: true}, {Name: "/:address/esdts-with-role/:role", Open: true}, {Name: "/:address/registered-nfts", Open: true}, + {Name: "/:address/is-data-trie-migrated", Open: true}, }, }, }, } } + +func TestIsDataTrieMigrated(t *testing.T) { + t.Parallel() + + testAddress := "address" + expectedErr := errors.New("expected error") + + t.Run("should return error if IsDataTrieMigrated returns error", func(t *testing.T) { + t.Parallel() + + facade := mock.FacadeStub{ + IsDataTrieMigratedCalled: func(address string, _ api.AccountQueryOptions) (bool, error) { + return false, expectedErr + }, + } + + addrGroup, err := groups.NewAddressGroup(&facade) + require.NoError(t, err) + ws := startWebServer(addrGroup, "address", getAddressRoutesConfig()) + + req, _ := http.NewRequest("GET", fmt.Sprintf("/address/%s/is-data-trie-migrated", testAddress), nil) + resp := httptest.NewRecorder() + ws.ServeHTTP(resp, req) + + response := shared.GenericAPIResponse{} + loadResponse(resp.Body, &response) + assert.Equal(t, http.StatusInternalServerError, resp.Code) + assert.True(t, strings.Contains(response.Error, expectedErr.Error())) + }) + + t.Run("should return true if IsDataTrieMigrated returns true", func(t *testing.T) { + t.Parallel() + + facade := mock.FacadeStub{ + IsDataTrieMigratedCalled: func(address string, _ api.AccountQueryOptions) (bool, error) { + return true, nil + }, + } + + addrGroup, err := groups.NewAddressGroup(&facade) + require.NoError(t, err) + ws := startWebServer(addrGroup, "address", getAddressRoutesConfig()) + + req, _ := http.NewRequest("GET", fmt.Sprintf("/address/%s/is-data-trie-migrated", testAddress), nil) + resp := httptest.NewRecorder() + ws.ServeHTTP(resp, req) + + response := shared.GenericAPIResponse{} + loadResponse(resp.Body, &response) + assert.Equal(t, http.StatusOK, resp.Code) + assert.True(t, response.Error == "") + + respData, ok := response.Data.(map[string]interface{}) + assert.True(t, ok) + assert.True(t, respData["isMigrated"].(bool)) + }) + + t.Run("should return false if IsDataTrieMigrated returns false", func(t *testing.T) { + t.Parallel() + + facade := mock.FacadeStub{ + IsDataTrieMigratedCalled: func(address string, _ api.AccountQueryOptions) (bool, error) { + return false, nil + }, + } + + addrGroup, err := groups.NewAddressGroup(&facade) + require.NoError(t, err) + ws := startWebServer(addrGroup, "address", getAddressRoutesConfig()) + + req, _ := http.NewRequest("GET", fmt.Sprintf("/address/%s/is-data-trie-migrated", testAddress), nil) + resp := httptest.NewRecorder() + ws.ServeHTTP(resp, req) + + response := shared.GenericAPIResponse{} + loadResponse(resp.Body, &response) + assert.Equal(t, http.StatusOK, resp.Code) + assert.True(t, response.Error == "") + + respData, ok := response.Data.(map[string]interface{}) + assert.True(t, ok) + assert.False(t, respData["isMigrated"].(bool)) + }) +} diff --git a/api/mock/facadeStub.go b/api/mock/facadeStub.go index b88c3e01709..76e52faf1a9 100644 --- a/api/mock/facadeStub.go +++ b/api/mock/facadeStub.go @@ -87,6 +87,7 @@ type FacadeStub struct { RestAPIServerDebugModeCalled func() bool PprofEnabledCalled func() bool DecodeAddressPubkeyCalled func(pk string) ([]byte, error) + IsDataTrieMigratedCalled func(address string, options api.AccountQueryOptions) (bool, error) } // GetTokenSupply - @@ -553,6 +554,15 @@ func (f *FacadeStub) GetInternalStartOfEpochValidatorsInfo(epoch uint32) ([]*sta return nil, nil } +// IsDataTrieMigrated - +func (f *FacadeStub) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) { + if f.IsDataTrieMigratedCalled != nil { + return f.IsDataTrieMigratedCalled(address, options) + } + + return false, nil +} + // Trigger - func (f *FacadeStub) Trigger(_ uint32, _ bool) error { return nil diff --git a/api/shared/interface.go b/api/shared/interface.go index f0a9a98359e..c9c69e3c009 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -126,5 +126,6 @@ type FacadeHandler interface { GetTransactionsPoolForSender(sender, fields string) (*common.TransactionsPoolForSenderApiResponse, error) GetLastPoolNonceForSender(sender string) (uint64, error) GetTransactionsPoolNonceGapsForSender(sender string) (*common.TransactionsPoolNonceGapsForSenderApiResponse, error) + IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool } diff --git a/cmd/node/config/api.toml b/cmd/node/config/api.toml index 87b59649910..aa48d8a367e 100644 --- a/cmd/node/config/api.toml +++ b/cmd/node/config/api.toml @@ -83,7 +83,10 @@ { Name = "/:address/esdts-with-role/:role", Open = true }, # /address/:address/registered-nfts will return the token identifiers of the tokens registered by the address - { Name = "/:address/registered-nfts", Open = true } + { Name = "/:address/registered-nfts", Open = true }, + + # /address/:address/is-data-trie-migrated will return the status of the data trie migration for the given address + { Name = "/:address/is-data-trie-migrated", Open = true } ] [APIPackages.hardfork] diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml index 959b5a361b4..2255ec8f080 100644 --- a/cmd/node/config/enableEpochs.toml +++ b/cmd/node/config/enableEpochs.toml @@ -224,6 +224,9 @@ # RuntimeMemStoreLimitEnableEpoch represents the epoch when the condition for Runtime MemStore is enabled RuntimeMemStoreLimitEnableEpoch = 1 + # AutoBalanceDataTriesEnableEpoch represents the epoch when the data tries are automatically balanced by inserting at the hashed key instead of the normal key + AutoBalanceDataTriesEnableEpoch = 5 + # SetSenderInEeiOutputTransferEnableEpoch represents the epoch when setting the sender in eei output transfers will be enabled SetSenderInEeiOutputTransferEnableEpoch = 1 diff --git a/cmd/node/config/gasSchedules/gasScheduleV1.toml b/cmd/node/config/gasSchedules/gasScheduleV1.toml index 66be1a6474b..2f2b4d17f79 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV1.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV1.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV2.toml b/cmd/node/config/gasSchedules/gasScheduleV2.toml index bfeb9a2595e..2d3d06f36a8 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV2.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV2.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV3.toml b/cmd/node/config/gasSchedules/gasScheduleV3.toml index 09a29cccbe0..d2b779c61c6 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV3.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV3.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV4.toml b/cmd/node/config/gasSchedules/gasScheduleV4.toml index 6de1f466876..4059fadada2 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV4.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV4.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV5.toml b/cmd/node/config/gasSchedules/gasScheduleV5.toml index 634275b1cd9..27d04e6ff5f 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV5.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV5.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV6.toml b/cmd/node/config/gasSchedules/gasScheduleV6.toml index 09229b8f15f..cdff4012b71 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV6.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV6.toml @@ -18,6 +18,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV7.toml b/cmd/node/config/gasSchedules/gasScheduleV7.toml index 3f31ac9969c..2a8a9ced22a 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV7.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV7.toml @@ -19,6 +19,8 @@ SetGuardian = 250000 GuardAccount = 250000 UnGuardAccount = 250000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/common/converters_test.go b/common/converters_test.go index a9f215cf9f0..43ca3e159fc 100644 --- a/common/converters_test.go +++ b/common/converters_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/common/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -122,7 +123,7 @@ func TestCalculateHash_Good(t *testing.T) { marshaledData := "marshalized random string" hashedData := "hashed marshalized random string" hash, err := core.CalculateHash( - &testscommon.MarshalizerStub{ + &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { marshalCalled = true assert.Equal(t, initialObject, obj) diff --git a/common/enablers/enableEpochsHandler.go b/common/enablers/enableEpochsHandler.go index cac29504579..2c5a41868f7 100644 --- a/common/enablers/enableEpochsHandler.go +++ b/common/enablers/enableEpochsHandler.go @@ -123,6 +123,7 @@ func (handler *enableEpochsHandler) EpochConfirmed(epoch uint32, _ uint64) { handler.setFlagValue(epoch >= handler.enableEpochsConfig.KeepExecOrderOnCreatedSCRsEnableEpoch, handler.keepExecOrderOnCreatedSCRsFlag, "keepExecOrderOnCreatedSCRsFlag", epoch, handler.enableEpochsConfig.KeepExecOrderOnCreatedSCRsEnableEpoch) handler.setFlagValue(epoch >= handler.enableEpochsConfig.ChangeUsernameEnableEpoch, handler.changeUsernameFlag, "changeUsername", epoch, handler.enableEpochsConfig.ChangeUsernameEnableEpoch) handler.setFlagValue(epoch >= handler.enableEpochsConfig.ConsistentTokensValuesLengthCheckEnableEpoch, handler.consistentTokensValuesCheckFlag, "consistentTokensValuesCheckFlag", epoch, handler.enableEpochsConfig.ConsistentTokensValuesLengthCheckEnableEpoch) + handler.setFlagValue(epoch >= handler.enableEpochsConfig.AutoBalanceDataTriesEnableEpoch, handler.autoBalanceDataTriesFlag, "autoBalanceDataTriesFlag", epoch, handler.enableEpochsConfig.AutoBalanceDataTriesEnableEpoch) } func (handler *enableEpochsHandler) setFlagValue(value bool, flag *atomic.Flag, flagName string, epoch uint32, flagEpoch uint32) { diff --git a/common/enablers/enableEpochsHandler_test.go b/common/enablers/enableEpochsHandler_test.go index 76589bebc95..e3528e8e11c 100644 --- a/common/enablers/enableEpochsHandler_test.go +++ b/common/enablers/enableEpochsHandler_test.go @@ -94,6 +94,7 @@ func createEnableEpochsConfig() config.EnableEpochs { MultiClaimOnDelegationEnableEpoch: 78, KeepExecOrderOnCreatedSCRsEnableEpoch: 79, ChangeUsernameEnableEpoch: 80, + AutoBalanceDataTriesEnableEpoch: 81, } } @@ -220,11 +221,12 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.False(t, handler.IsKeepExecOrderOnCreatedSCRsEnabled()) assert.False(t, handler.IsMultiClaimOnDelegationEnabled()) assert.False(t, handler.IsChangeUsernameEnabled()) + assert.False(t, handler.IsAutoBalanceDataTriesEnabled()) }) t.Run("flags with == condition should be set, along with all >=", func(t *testing.T) { t.Parallel() - epoch := uint32(80) + epoch := uint32(81) cfg := createEnableEpochsConfig() cfg.StakingV2EnableEpoch = epoch cfg.ESDTEnableEpoch = epoch @@ -322,6 +324,7 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.True(t, handler.IsRuntimeCodeSizeFixEnabled()) assert.True(t, handler.IsKeepExecOrderOnCreatedSCRsEnabled()) assert.True(t, handler.IsChangeUsernameEnabled()) + assert.True(t, handler.IsAutoBalanceDataTriesEnabled()) }) t.Run("flags with < should be set", func(t *testing.T) { t.Parallel() @@ -419,5 +422,6 @@ func TestNewEnableEpochsHandler_EpochConfirmed(t *testing.T) { assert.False(t, handler.IsRuntimeCodeSizeFixEnabled()) assert.False(t, handler.IsKeepExecOrderOnCreatedSCRsEnabled()) assert.False(t, handler.IsChangeUsernameEnabled()) + assert.False(t, handler.IsAutoBalanceDataTriesEnabled()) }) } diff --git a/common/enablers/epochFlags.go b/common/enablers/epochFlags.go index 6f0efe9c0d1..a92b961aea4 100644 --- a/common/enablers/epochFlags.go +++ b/common/enablers/epochFlags.go @@ -95,6 +95,7 @@ type epochFlagsHolder struct { multiClaimOnDelegationFlag *atomic.Flag changeUsernameFlag *atomic.Flag consistentTokensValuesCheckFlag *atomic.Flag + autoBalanceDataTriesFlag *atomic.Flag } func newEpochFlagsHolder() *epochFlagsHolder { @@ -189,6 +190,7 @@ func newEpochFlagsHolder() *epochFlagsHolder { consistentTokensValuesCheckFlag: &atomic.Flag{}, multiClaimOnDelegationFlag: &atomic.Flag{}, changeUsernameFlag: &atomic.Flag{}, + autoBalanceDataTriesFlag: &atomic.Flag{}, } } @@ -694,3 +696,8 @@ func (holder *epochFlagsHolder) IsMultiClaimOnDelegationEnabled() bool { func (holder *epochFlagsHolder) IsChangeUsernameEnabled() bool { return holder.changeUsernameFlag.IsSet() } + +// IsAutoBalanceDataTriesEnabled returns true if autoBalanceDataTriesFlag is enabled +func (holder *epochFlagsHolder) IsAutoBalanceDataTriesEnabled() bool { + return holder.autoBalanceDataTriesFlag.IsSet() +} diff --git a/common/interface.go b/common/interface.go index fb9d83a0150..eac852a0111 100644 --- a/common/interface.go +++ b/common/interface.go @@ -8,7 +8,6 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" - "github.com/multiversx/mx-chain-go/trie/statistics" ) // TrieIteratorChannels defines the channels that are being used when iterating the trie nodes @@ -17,6 +16,17 @@ type TrieIteratorChannels struct { ErrChan BufferedErrChan } +// TrieType defines the type of the trie +type TrieType string + +const ( + // MainTrie represents the main trie in which all the accounts and SC code are stored + MainTrie TrieType = "mainTrie" + + // DataTrie represents a data trie in which all the data related to an account is stored + DataTrie TrieType = "dataTrie" +) + // BufferedErrChan is an interface that defines the methods for a buffered error channel type BufferedErrChan interface { WriteInChanNonBlocking(err error) @@ -40,18 +50,25 @@ type Trie interface { GetOldRoot() []byte GetSerializedNodes([]byte, uint64) ([][]byte, uint64, error) GetSerializedNode([]byte) ([]byte, error) - GetAllLeavesOnChannel(allLeavesChan *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder) error + GetAllLeavesOnChannel(allLeavesChan *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder, trieLeafParser TrieLeafParser) error GetAllHashes() ([][]byte, error) GetProof(key []byte) ([][]byte, []byte, error) VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error) GetStorageManager() StorageManager + IsMigratedToLatestVersion() (bool, error) Close() error IsInterfaceNil() bool } +// TrieLeafParser is used to parse trie leaves +type TrieLeafParser interface { + ParseLeaf(key []byte, val []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) + IsInterfaceNil() bool +} + // TrieStats is used to collect the trie statistics for the given rootHash type TrieStats interface { - GetTrieStats(address string, rootHash []byte) (*statistics.TrieStatsDTO, error) + GetTrieStats(address string, rootHash []byte) (TrieStatisticsHandler, error) } // StorageMarker is used to mark the given storer as synced and active @@ -65,12 +82,14 @@ type KeyBuilder interface { BuildKey(keyPart []byte) GetKey() ([]byte, error) Clone() KeyBuilder + IsInterfaceNil() bool } // DataTrieHandler is an interface that declares the methods used for dataTries type DataTrieHandler interface { RootHash() ([]byte, error) - GetAllLeavesOnChannel(leavesChannels *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder) error + GetAllLeavesOnChannel(leavesChannels *TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder KeyBuilder, trieLeafParser TrieLeafParser) error + IsMigratedToLatestVersion() (bool, error) IsInterfaceNil() bool } @@ -162,21 +181,36 @@ type SnapshotStatisticsHandler interface { SnapshotFinished() NewSnapshotStarted() WaitForSnapshotsToFinish() - AddTrieStats(*statistics.TrieStatsDTO) + AddTrieStats(handler TrieStatisticsHandler, trieType TrieType) + IsInterfaceNil() bool } // TrieStatisticsHandler is used to collect different statistics about a single trie type TrieStatisticsHandler interface { AddBranchNode(level int, size uint64) AddExtensionNode(level int, size uint64) - AddLeafNode(level int, size uint64) + AddLeafNode(level int, size uint64, version core.TrieNodeVersion) AddAccountInfo(address string, rootHash []byte) - GetTrieStats() *statistics.TrieStatsDTO + + GetTotalNodesSize() uint64 + GetTotalNumNodes() uint64 + GetMaxTrieDepth() uint32 + GetBranchNodesSize() uint64 + GetNumBranchNodes() uint64 + GetExtensionNodesSize() uint64 + GetNumExtensionNodes() uint64 + GetLeafNodesSize() uint64 + GetNumLeafNodes() uint64 + GetLeavesMigrationStats() map[core.TrieNodeVersion]uint64 + + MergeTriesStatistics(statsToBeMerged TrieStatisticsHandler) + ToString() []string + IsInterfaceNil() bool } // TriesStatisticsCollector is used to merge the statistics for multiple tries type TriesStatisticsCollector interface { - Add(trieStats *statistics.TrieStatsDTO) + Add(trieStats TrieStatisticsHandler, trieType TrieType) Print() GetNumNodes() uint64 } @@ -351,6 +385,7 @@ type EnableEpochsHandler interface { IsMultiClaimOnDelegationEnabled() bool IsChangeUsernameEnabled() bool IsConsistentTokensValuesLengthCheckEnabled() bool + IsAutoBalanceDataTriesEnabled() bool IsInterfaceNil() bool } diff --git a/common/trie.go b/common/trie.go index eeda9925561..510029a1dc7 100644 --- a/common/trie.go +++ b/common/trie.go @@ -1,6 +1,10 @@ package common -import "bytes" +import ( + "bytes" + + "github.com/multiversx/mx-chain-core-go/core" +) // EmptyTrieHash returns the value with empty trie hash var EmptyTrieHash = make([]byte, 32) @@ -15,3 +19,17 @@ func IsEmptyTrie(root []byte) bool { } return false } + +// TrimSuffixFromValue returns the value without the suffix +func TrimSuffixFromValue(value []byte, suffixLength int) ([]byte, error) { + if suffixLength == 0 { + return value, nil + } + + dataLength := len(value) - suffixLength + if dataLength < 0 { + return nil, core.ErrSuffixNotPresentOrInIncorrectPosition + } + + return value[:dataLength], nil +} diff --git a/config/config.go b/config/config.go index 7ffd1e35d37..1f429ad2de3 100644 --- a/config/config.go +++ b/config/config.go @@ -21,12 +21,14 @@ type HeadersPoolConfig struct { // DBConfig will map the database configuration type DBConfig struct { - FilePath string - Type string - BatchDelaySeconds int - MaxBatchSize int - MaxOpenFiles int - UseTmpAsFilePath bool + FilePath string + Type string + BatchDelaySeconds int + MaxBatchSize int + MaxOpenFiles int + UseTmpAsFilePath bool + ShardIDProviderType string + NumShards int32 } // StorageConfig will map the storage unit configuration diff --git a/config/epochConfig.go b/config/epochConfig.go index 838af4b95bf..c4a7efbebcc 100644 --- a/config/epochConfig.go +++ b/config/epochConfig.go @@ -96,6 +96,7 @@ type EnableEpochs struct { KeepExecOrderOnCreatedSCRsEnableEpoch uint32 MultiClaimOnDelegationEnableEpoch uint32 ChangeUsernameEnableEpoch uint32 + AutoBalanceDataTriesEnableEpoch uint32 BLSMultiSignerEnableEpoch []MultiSignerConfig SetGuardianEnableEpoch uint32 ConsistentTokensValuesLengthCheckEnableEpoch uint32 diff --git a/dataRetriever/factory/storageRequestersContainer/args.go b/dataRetriever/factory/storageRequestersContainer/args.go index 70a2db6501e..2e498ba6f15 100644 --- a/dataRetriever/factory/storageRequestersContainer/args.go +++ b/dataRetriever/factory/storageRequestersContainer/args.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/sharding" @@ -26,4 +27,5 @@ type FactoryArgs struct { ManualEpochStartNotifier dataRetriever.ManualEpochStartNotifier ChanGracefullyClose chan endProcess.ArgEndProcess SnapshotsEnabled bool + EnableEpochsHandler common.EnableEpochsHandler } diff --git a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go index d7407bdb1ba..e0bccf04e75 100644 --- a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" disabledRequesters "github.com/multiversx/mx-chain-go/dataRetriever/requestHandlers/requesters/disabled" "github.com/multiversx/mx-chain-go/dataRetriever/storageRequesters" + "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/storage" @@ -34,6 +35,7 @@ type baseRequestersContainerFactory struct { uint64ByteSliceConverter typeConverters.Uint64ByteSliceConverter dataPacker dataRetriever.DataPacker manualEpochStartNotifier dataRetriever.ManualEpochStartNotifier + enableEpochsHandler common.EnableEpochsHandler chanGracefullyClose chan endProcess.ArgEndProcess generalConfig config.Config shardIDForTries uint32 @@ -70,6 +72,9 @@ func (brcf *baseRequestersContainerFactory) checkParams() error { if check.IfNil(brcf.hasher) { return dataRetriever.ErrNilHasher } + if check.IfNil(brcf.enableEpochsHandler) { + return errors.ErrNilEnableEpochsHandler + } return nil } @@ -232,6 +237,7 @@ func (brcf *baseRequestersContainerFactory) newImportDBTrieStorage( mainStorer storage.Storer, checkpointsStorer storage.Storer, storageIdentifier dataRetriever.UnitType, + handler common.EnableEpochsHandler, ) (common.StorageManager, dataRetriever.TrieDataGetter, error) { pathManager, err := storageFactory.CreatePathManager( storageFactory.ArgCreatePathManager{ @@ -263,6 +269,7 @@ func (brcf *baseRequestersContainerFactory) newImportDBTrieStorage( SnapshotsEnabled: brcf.snapshotsEnabled, IdleProvider: disabled.NewProcessStatusHandler(), Identifier: storageIdentifier.String(), + EnableEpochsHandler: handler, } return trieFactoryInstance.Create(args) } diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go index 092ef541a5c..5ff8809a81b 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go @@ -37,6 +37,7 @@ func NewMetaRequestersContainerFactory( chainID: args.ChainID, workingDir: args.WorkingDirectory, snapshotsEnabled: args.SnapshotsEnabled, + enableEpochsHandler: args.EnableEpochsHandler, } err := base.checkParams() @@ -195,6 +196,7 @@ func (mrcf *metaRequestersContainerFactory) generateTrieNodesRequesters() error userAccountsStorer, userAccountsCheckpointStorer, dataRetriever.UserAccountsUnit, + mrcf.enableEpochsHandler, ) if err != nil { return fmt.Errorf("%w while creating user accounts data trie storage getter", err) @@ -232,6 +234,7 @@ func (mrcf *metaRequestersContainerFactory) generateTrieNodesRequesters() error peerAccountsStorer, peerAccountsCheckpointStorer, dataRetriever.PeerAccountsUnit, + mrcf.enableEpochsHandler, ) if err != nil { return fmt.Errorf("%w while creating peer accounts data trie storage getter", err) diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go index a53aca90aaf..6711a3d58c4 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -224,5 +225,6 @@ func getArgumentsMeta() storagerequesterscontainer.FactoryArgs { ManualEpochStartNotifier: &mock.ManualEpochStartNotifierStub{}, ChanGracefullyClose: make(chan endProcess.ArgEndProcess), SnapshotsEnabled: true, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go index dcf0acf6583..59aab96c7ce 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go @@ -37,6 +37,7 @@ func NewShardRequestersContainerFactory( chainID: args.ChainID, workingDir: args.WorkingDirectory, snapshotsEnabled: args.SnapshotsEnabled, + enableEpochsHandler: args.EnableEpochsHandler, } err := base.checkParams() @@ -170,6 +171,7 @@ func (srcf *shardRequestersContainerFactory) generateTrieNodesRequesters() error userAccountsStorer, userAccountsCheckpointStorer, dataRetriever.UserAccountsUnit, + srcf.enableEpochsHandler, ) if err != nil { return fmt.Errorf("%w while creating user accounts data trie storage getter", err) diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go index 71319735278..cc7a22af6c8 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -209,5 +210,6 @@ func getArgumentsShard() storagerequesterscontainer.FactoryArgs { ManualEpochStartNotifier: &mock.ManualEpochStartNotifierStub{}, ChanGracefullyClose: make(chan endProcess.ArgEndProcess), SnapshotsEnabled: true, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/dataRetriever/provider/miniBlocks_test.go b/dataRetriever/provider/miniBlocks_test.go index 3ccbeba3490..dc0e4f206e8 100644 --- a/dataRetriever/provider/miniBlocks_test.go +++ b/dataRetriever/provider/miniBlocks_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/provider" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -247,7 +248,7 @@ func TestMiniBlockProvider_GetMiniBlocksFromStorerShouldBeFoundInStorage(t *test cnt := 0 arg := createMockMiniblockProviderArgs(nil, existingHashes) - arg.Marshalizer = &testscommon.MarshalizerStub{ + arg.Marshalizer = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { cnt++ if cnt == 1 { diff --git a/dataRetriever/requestHandlers/requesters/baseRequester_test.go b/dataRetriever/requestHandlers/requesters/baseRequester_test.go index 7ec2f425b3f..10be2adfcae 100644 --- a/dataRetriever/requestHandlers/requesters/baseRequester_test.go +++ b/dataRetriever/requestHandlers/requesters/baseRequester_test.go @@ -6,15 +6,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" - "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" ) func createMockArgBaseRequester() ArgBaseRequester { return ArgBaseRequester{ RequestSender: &dataRetrieverMocks.TopicRequestSenderStub{}, - Marshaller: &testscommon.MarshalizerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, } } @@ -33,7 +33,7 @@ func Test_checkArgBase(t *testing.T) { err := checkArgBase(ArgBaseRequester{ RequestSender: nil, - Marshaller: &testscommon.MarshalizerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, }) assert.Equal(t, err, dataRetriever.ErrNilRequestSender) }) @@ -73,7 +73,7 @@ func TestBaseRequester_RequestDataFromHash(t *testing.T) { } baseHandler := createBaseRequester(ArgBaseRequester{ RequestSender: requestSender, - Marshaller: &testscommon.MarshalizerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, }) assert.False(t, check.IfNilReflect(baseHandler)) @@ -99,7 +99,7 @@ func TestBaseRequester_NumPeersToQuery(t *testing.T) { } baseHandler := createBaseRequester(ArgBaseRequester{ RequestSender: requestSender, - Marshaller: &testscommon.MarshalizerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, }) assert.False(t, check.IfNilReflect(baseHandler)) @@ -126,7 +126,7 @@ func TestBaseRequester_SetDebugHandler(t *testing.T) { } baseHandler := createBaseRequester(ArgBaseRequester{ RequestSender: requestSender, - Marshaller: &testscommon.MarshalizerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, }) assert.False(t, check.IfNilReflect(baseHandler)) diff --git a/dataRetriever/resolvers/validatorInfoResolver_test.go b/dataRetriever/resolvers/validatorInfoResolver_test.go index 0d5916c710e..19f659660f9 100644 --- a/dataRetriever/resolvers/validatorInfoResolver_test.go +++ b/dataRetriever/resolvers/validatorInfoResolver_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -209,19 +210,19 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("data found in cache but marshal fails", func(t *testing.T) { t.Parallel() - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ SearchFirstDataCalled: func(key []byte) (value interface{}, ok bool) { return []byte("some value"), true }, } - args.Marshaller = &testscommon.MarshalizerStub{ + args.Marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, UnmarshalCalled: func(obj interface{}, buff []byte) error { - return marshallerMock.Unmarshal(obj, buff) + return marshMock.Unmarshal(obj, buff) }, } res, _ := resolvers.NewValidatorInfoResolver(args) @@ -233,7 +234,7 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Run("data found in storage but marshal fails", func(t *testing.T) { t.Parallel() - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} args := createMockArgValidatorInfoResolver() args.ValidatorInfoPool = &testscommon.ShardedDataStub{ SearchFirstDataCalled: func(key []byte) (value interface{}, ok bool) { @@ -245,12 +246,12 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { return []byte("some value"), nil }, } - args.Marshaller = &testscommon.MarshalizerStub{ + args.Marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, UnmarshalCalled: func(obj interface{}, buff []byte) error { - return marshallerMock.Unmarshal(obj, buff) + return marshMock.Unmarshal(obj, buff) }, } res, _ := resolvers.NewValidatorInfoResolver(args) @@ -272,12 +273,12 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { } args.SenderResolver = &mock.TopicResolverSenderStub{ SendCalled: func(buff []byte, peer core.PeerID) error { - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} b := &batch.Batch{} - _ = marshallerMock.Unmarshal(b, buff) + _ = marshMock.Unmarshal(b, buff) vi := &state.ValidatorInfo{} - _ = marshallerMock.Unmarshal(vi, b.Data[0]) + _ = marshMock.Unmarshal(vi, b.Data[0]) assert.Equal(t, &providedValue, vi) wasCalled = true @@ -305,18 +306,18 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { } args.ValidatorInfoStorage = &storage.StorerStub{ SearchFirstCalled: func(key []byte) ([]byte, error) { - marshallerMock := testscommon.MarshalizerMock{} - return marshallerMock.Marshal(providedValue) + marshMock := marshallerMock.MarshalizerMock{} + return marshMock.Marshal(providedValue) }, } args.SenderResolver = &mock.TopicResolverSenderStub{ SendCalled: func(buff []byte, peer core.PeerID) error { - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} b := &batch.Batch{} - _ = marshallerMock.Unmarshal(b, buff) + _ = marshMock.Unmarshal(b, buff) vi := &state.ValidatorInfo{} - _ = marshallerMock.Unmarshal(vi, b.Data[0]) + _ = marshMock.Unmarshal(vi, b.Data[0]) assert.Equal(t, &providedValue, vi) wasCalled = true @@ -337,11 +338,11 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { t.Parallel() args := createMockArgValidatorInfoResolver() - args.Marshaller = &testscommon.MarshalizerStub{ + args.Marshaller = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { switch obj.(type) { case *dataRetriever.RequestData: - return testscommon.MarshalizerMock{}.Unmarshal(obj, buff) + return marshallerMock.MarshalizerMock{}.Unmarshal(obj, buff) case *batch.Batch: return expectedErr } @@ -466,14 +467,14 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { } args.SenderResolver = &mock.TopicResolverSenderStub{ SendCalled: func(buff []byte, peer core.PeerID) error { - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} b := &batch.Batch{} - _ = marshallerMock.Unmarshal(b, buff) + _ = marshMock.Unmarshal(b, buff) assert.Equal(t, numOfProvidedData, len(b.Data)) for i := 0; i < numOfProvidedData; i++ { vi := &state.ValidatorInfo{} - _ = marshallerMock.Unmarshal(vi, b.Data[i]) + _ = marshMock.Unmarshal(vi, b.Data[i]) assert.Equal(t, &providedData[i], vi) } @@ -499,7 +500,7 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { providedHashes := make([][]byte, 0) providedData := make([]state.ValidatorInfo, 0) testHasher := hashingMocks.HasherMock{} - testMarshaller := testscommon.MarshalizerMock{} + testMarshaller := marshallerMock.MarshalizerMock{} providedDataMap := make(map[string]struct{}, 0) for i := 0; i < numOfProvidedData; i++ { hashStr := fmt.Sprintf("hash%d", i) @@ -524,14 +525,14 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { numOfCallsSend := 0 args.SenderResolver = &mock.TopicResolverSenderStub{ SendCalled: func(buff []byte, peer core.PeerID) error { - marshallerMock := testscommon.MarshalizerMock{} + marshMock := marshallerMock.MarshalizerMock{} b := &batch.Batch{} - _ = marshallerMock.Unmarshal(b, buff) + _ = marshMock.Unmarshal(b, buff) dataLen := len(b.Data) for i := 0; i < dataLen; i++ { vi := &state.ValidatorInfo{} - _ = marshallerMock.Unmarshal(vi, b.Data[i]) + _ = marshMock.Unmarshal(vi, b.Data[i]) // remove this info from the provided map validatorInfoBuff, err := testMarshaller.Marshal(vi) diff --git a/dblookupext/esdtSupply/esdtSuppliesProcessor_test.go b/dblookupext/esdtSupply/esdtSuppliesProcessor_test.go index 8505c7543a4..56084ced28b 100644 --- a/dblookupext/esdtSupply/esdtSuppliesProcessor_test.go +++ b/dblookupext/esdtSupply/esdtSuppliesProcessor_test.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -33,13 +34,13 @@ func TestNewSuppliesProcessor(t *testing.T) { _, err := NewSuppliesProcessor(nil, &storageStubs.StorerStub{}, &storageStubs.StorerStub{}) require.Equal(t, core.ErrNilMarshalizer, err) - _, err = NewSuppliesProcessor(&testscommon.MarshalizerMock{}, nil, &storageStubs.StorerStub{}) + _, err = NewSuppliesProcessor(&marshallerMock.MarshalizerMock{}, nil, &storageStubs.StorerStub{}) require.Equal(t, core.ErrNilStore, err) - _, err = NewSuppliesProcessor(&testscommon.MarshalizerMock{}, &storageStubs.StorerStub{}, nil) + _, err = NewSuppliesProcessor(&marshallerMock.MarshalizerMock{}, &storageStubs.StorerStub{}, nil) require.Equal(t, core.ErrNilStore, err) - proc, err := NewSuppliesProcessor(&testscommon.MarshalizerMock{}, &storageStubs.StorerStub{}, &storageStubs.StorerStub{}) + proc, err := NewSuppliesProcessor(&marshallerMock.MarshalizerMock{}, &storageStubs.StorerStub{}, &storageStubs.StorerStub{}) require.Nil(t, err) require.NotNil(t, proc) require.False(t, proc.IsInterfaceNil()) @@ -102,7 +103,7 @@ func TestProcessLogsSaveSupply(t *testing.T) { } putCalledNum := 0 - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} suppliesStorer := &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { if string(key) == "processed-block" { @@ -230,7 +231,7 @@ func TestProcessLogsSaveSupplyShouldUpdateSupplyMintedAndBurned(t *testing.T) { } membDB := testscommon.NewMemDbMock() - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} numTimesCalled := 0 suppliesStorer := &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { @@ -361,7 +362,7 @@ func TestProcessLogs_RevertChangesShouldWorkForRevertingMinting(t *testing.T) { }, } - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} logsStorer := genericMocks.NewStorerMockWithErrKeyNotFound(0) mintLogToBeRevertedBytes, err := marshalizer.Marshal(mintLogToBeReverted) @@ -452,7 +453,7 @@ func TestProcessLogs_RevertChangesShouldWorkForRevertingBurning(t *testing.T) { }, } - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} logsStorer := genericMocks.NewStorerMockWithErrKeyNotFound(0) mintLogToBeRevertedBytes, err := marshalizer.Marshal(mintLogToBeReverted) @@ -525,7 +526,7 @@ func getSupplyESDT(marshalizer marshal.Marshalizer, data []byte) SupplyESDT { func TestSupplyESDT_GetSupply(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} proc, _ := NewSuppliesProcessor(marshalizer, &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { if string(key) == "my-token" { diff --git a/dblookupext/esdtSupply/logsGetter_test.go b/dblookupext/esdtSupply/logsGetter_test.go index 6ccdd0492cf..5e2359def5b 100644 --- a/dblookupext/esdtSupply/logsGetter_test.go +++ b/dblookupext/esdtSupply/logsGetter_test.go @@ -7,7 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -15,7 +15,7 @@ import ( func TestGetLogsBasedOnBody(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} txHash := []byte("txHash") scrHash := []byte("scrHash") @@ -62,7 +62,7 @@ func TestGetLogsBasedOnBody(t *testing.T) { func TestGetLogsWrongBodyType(t *testing.T) { t.Parallel() - getter := newLogsGetter(&testscommon.MarshalizerMock{}, &storageStubs.StorerStub{}) + getter := newLogsGetter(&marshallerMock.MarshalizerMock{}, &storageStubs.StorerStub{}) _, err := getter.getLogsBasedOnBody(nil) require.Equal(t, errCannotCastToBlockBody, err) diff --git a/dblookupext/esdtSupply/logsProcessor_test.go b/dblookupext/esdtSupply/logsProcessor_test.go index 8de850c1f6c..6512e1d28ba 100644 --- a/dblookupext/esdtSupply/logsProcessor_test.go +++ b/dblookupext/esdtSupply/logsProcessor_test.go @@ -9,7 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -69,7 +69,7 @@ func TestProcessLogsSaveSupplyNothingInStorage(t *testing.T) { } putCalledNum := 0 - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} storer := &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { return nil, storage.ErrKeyNotFound @@ -128,7 +128,7 @@ func TestTestProcessLogsSaveSupplyExistsInStorage(t *testing.T) { }, } - marshalizer := testscommon.MarshalizerMock{} + marshalizer := marshallerMock.MarshalizerMock{} storer := &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { supplyESDT := &SupplyESDT{ diff --git a/dblookupext/esdtSupply/nonceProcessor_test.go b/dblookupext/esdtSupply/nonceProcessor_test.go index f04085ca64a..f7bd2377224 100644 --- a/dblookupext/esdtSupply/nonceProcessor_test.go +++ b/dblookupext/esdtSupply/nonceProcessor_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -12,7 +12,7 @@ import ( func TestNonceProcessor_shouldProcessLogs_currentNonceLowerThanProcessed(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} nonceProc := newNonceProcessor(marshalizer, &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { processedBlockNonce := &ProcessedBlockNonce{ @@ -30,7 +30,7 @@ func TestNonceProcessor_shouldProcessLogs_currentNonceLowerThanProcessed(t *test func TestNonceProcessor_shouldProcessLogs_currentNonceHigherThanProcessed(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} nonceProc := newNonceProcessor(marshalizer, &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { processedBlockNonce := &ProcessedBlockNonce{ @@ -48,7 +48,7 @@ func TestNonceProcessor_shouldProcessLogs_currentNonceHigherThanProcessed(t *tes func TestNonceProcessor_shouldProcessLogs_nothingInStorageShouldProcess(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} nonceProc := newNonceProcessor(marshalizer, &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { return nil, storage.ErrKeyNotFound @@ -63,7 +63,7 @@ func TestNonceProcessor_shouldProcessLogs_nothingInStorageShouldProcess(t *testi func TestNonceProcessor_shouldProcessLogs_revertNothingInStorage(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} nonceProc := newNonceProcessor(marshalizer, &storageStubs.StorerStub{ GetCalled: func(key []byte) ([]byte, error) { return nil, storage.ErrKeyNotFound @@ -78,7 +78,7 @@ func TestNonceProcessor_shouldProcessLogs_revertNothingInStorage(t *testing.T) { func TestNonceProcessor_saveNonceInStorage(t *testing.T) { t.Parallel() - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} nonceProc := newNonceProcessor(marshalizer, &storageStubs.StorerStub{ PutCalled: func(key, data []byte) error { require.Equal(t, []byte(processedBlockKey), key) diff --git a/debug/process/stateExport.go b/debug/process/stateExport.go index 831aaebfc0e..421c52edd9d 100644 --- a/debug/process/stateExport.go +++ b/debug/process/stateExport.go @@ -46,7 +46,7 @@ func getCodeAndData(accountsDB state.AccountsAdapter, address []byte) (code []by rootHash := userAccount.GetRootHash() if len(rootHash) > 0 { - csvHexedData, err = getData(accountsDB, rootHash, address) + csvHexedData, err = getData(userAccount) if err != nil { return nil, nil, err } @@ -64,36 +64,29 @@ func getCode(accountsDB state.AccountsAdapter, codeHash []byte) ([]byte, error) return code, nil } -func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte) ([]string, error) { +func getData(account state.UserAccountHandler) ([]string, error) { leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), ErrChan: errChan.NewErrChanWrapper(), } - err := accountsDB.GetAllLeaves(leavesChannels, context.Background(), rootHash) + err := account.GetAllLeaves(leavesChannels, context.Background()) if err != nil { return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s", - err, hex.EncodeToString(rootHash), hex.EncodeToString(address)) + err, hex.EncodeToString(account.GetRootHash()), hex.EncodeToString(account.AddressBytes())) } lines := make([]string, 0) for keyVal := range leavesChannels.LeavesChan { - suffix := append(keyVal.Key(), address...) - valWithoutSuffix, errTrim := keyVal.ValueWithoutSuffix(suffix) - if errTrim != nil { - return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s", - errTrim, hex.EncodeToString(rootHash), hex.EncodeToString(address)) - } - lines = append(lines, fmt.Sprintf("%s,%s", hex.EncodeToString(keyVal.Key()), - hex.EncodeToString(valWithoutSuffix))) + hex.EncodeToString(keyVal.Value()))) } err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s", - err, hex.EncodeToString(rootHash), hex.EncodeToString(address)) + err, hex.EncodeToString(account.GetRootHash()), hex.EncodeToString(account.AddressBytes())) } return lines, nil diff --git a/epochStart/bootstrap/baseStorageHandler_test.go b/epochStart/bootstrap/baseStorageHandler_test.go index 4cb1474b925..2c1b294f629 100644 --- a/epochStart/bootstrap/baseStorageHandler_test.go +++ b/epochStart/bootstrap/baseStorageHandler_test.go @@ -6,15 +6,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSaveMiniBlocksFromComponents(t *testing.T) { - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} mb1 := &block.MiniBlock{ diff --git a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go index 49e328ce8e5..753ee93fd14 100644 --- a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go +++ b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go @@ -118,7 +118,7 @@ func (a *accountsAdapter) ClosePersister() error { } // GetAllLeaves - -func (a *accountsAdapter) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context, _ []byte) error { +func (a *accountsAdapter) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { return nil } diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 743250c5bde..aae8cc137ff 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -21,7 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" factoryDataPool "github.com/multiversx/mx-chain-go/dataRetriever/factory" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" - "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" + requesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" "github.com/multiversx/mx-chain-go/dataRetriever/requestHandlers" "github.com/multiversx/mx-chain-go/epochStart" @@ -1083,6 +1083,7 @@ func (e *epochStartBootstrap) syncUserAccountsState(rootHash []byte) error { CheckNodesOnDisk: e.checkNodesOnDisk, UserAccountsSyncStatisticsHandler: e.trieSyncStatisticsProvider, AppStatusHandler: e.statusHandler, + EnableEpochsHandler: e.coreComponentsHolder.EnableEpochsHandler(), }, ShardId: e.shardCoordinator.SelfId(), Throttler: thr, @@ -1155,6 +1156,7 @@ func (e *epochStartBootstrap) syncValidatorAccountsState(rootHash []byte) error CheckNodesOnDisk: e.checkNodesOnDisk, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), AppStatusHandler: disabledCommon.NewAppStatusHandler(), + EnableEpochsHandler: e.coreComponentsHolder.EnableEpochsHandler(), }, } accountsDBSyncer, err := syncer.NewValidatorAccountsSyncer(argsValidatorAccountsSyncer) diff --git a/epochStart/bootstrap/process_test.go b/epochStart/bootstrap/process_test.go index 6ac7b8899ee..159f069874a 100644 --- a/epochStart/bootstrap/process_test.go +++ b/epochStart/bootstrap/process_test.go @@ -32,9 +32,11 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/scheduledDataSyncer" @@ -77,7 +79,7 @@ func createComponentsForEpochStart() (*mock.CoreComponentsMock, *mock.CryptoComp NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, ProcessStatusHandlerInstance: &testscommon.ProcessStatusHandlerStub{}, HardforkTriggerPubKeyField: []byte("provided hardfork pub key"), - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, &mock.CryptoComponentsMock{ PubKey: &cryptoMocks.PublicKeyStub{}, @@ -1242,7 +1244,7 @@ func TestRequestAndProcessForShard_ShouldFail(t *testing.T) { expectedErr := errors.New("expected error") coreComp, cryptoComp := createComponentsForEpochStart() - coreComp.IntMarsh = &testscommon.MarshalizerStub{ + coreComp.IntMarsh = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -1376,7 +1378,7 @@ func TestRequestAndProcessForMeta_ShouldFail(t *testing.T) { expectedErr := errors.New("expected error") coreComp, cryptoComp := createComponentsForEpochStart() - coreComp.IntMarsh = &testscommon.MarshalizerStub{ + coreComp.IntMarsh = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, diff --git a/epochStart/bootstrap/shardStorageHandler_test.go b/epochStart/bootstrap/shardStorageHandler_test.go index 45c1ee48abd..a4566775be3 100644 --- a/epochStart/bootstrap/shardStorageHandler_test.go +++ b/epochStart/bootstrap/shardStorageHandler_test.go @@ -28,6 +28,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -985,7 +986,7 @@ func TestShardStorageHandler_saveLastCrossNotarizedHeadersWithoutScheduledErrorW args := createDefaultShardStorageArgs() expectedErr := fmt.Errorf("expected error") // Simulate an error when writing to storage with a mock marshaller - args.marshalizer = &testscommon.MarshalizerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { + args.marshalizer = &marshallerMock.MarshalizerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }} shardStorage, _ := NewShardStorageHandler( diff --git a/epochStart/bootstrap/storageProcess.go b/epochStart/bootstrap/storageProcess.go index 1b4578c88c9..ac902d0c1f3 100644 --- a/epochStart/bootstrap/storageProcess.go +++ b/epochStart/bootstrap/storageProcess.go @@ -253,6 +253,7 @@ func (sesb *storageEpochStartBootstrap) createStorageRequesters() error { ManualEpochStartNotifier: mesn, ChanGracefullyClose: sesb.chanGracefullyClose, SnapshotsEnabled: sesb.flagsConfig.SnapshotsEnabled, + EnableEpochsHandler: sesb.coreComponentsHolder.EnableEpochsHandler(), } var requestersContainerFactory dataRetriever.RequestersContainerFactory diff --git a/epochStart/bootstrap/syncValidatorStatus_test.go b/epochStart/bootstrap/syncValidatorStatus_test.go index 01ddc3c4c38..f7e409af875 100644 --- a/epochStart/bootstrap/syncValidatorStatus_test.go +++ b/epochStart/bootstrap/syncValidatorStatus_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -306,6 +307,6 @@ func getSyncValidatorStatusArgs() ArgsNewSyncValidatorStatus { ChanNodeStop: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/epochStart/metachain/baseRewards_test.go b/epochStart/metachain/baseRewards_test.go index 57fd8ad7a9b..44bea20759b 100644 --- a/epochStart/metachain/baseRewards_test.go +++ b/epochStart/metachain/baseRewards_test.go @@ -22,7 +22,9 @@ import ( "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -833,8 +835,13 @@ func TestBaseRewardsCreator_isSystemDelegationSC(t *testing.T) { isDelegationSCAddress = rwd.isSystemDelegationSC(peerAccount.AddressBytes()) require.False(t, isDelegationSCAddress) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } // existing user account - userAccount, err := state.NewUserAccount([]byte("userAddress")) + userAccount, err := state.NewUserAccount([]byte("userAddress"), argsAccCreation) require.Nil(t, err) userAccount.SetDataTrie(&trieMock.TrieStub{ @@ -1141,7 +1148,16 @@ func getBaseRewardsArguments() BaseRewardsCreatorArgs { storageManagerArgs.Hasher = hasher trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, storage.GetStorageManagerOptions()) - userAccountsDB := createAccountsDB(hasher, marshalizer, factory.NewAccountCreator(), trieFactoryManager) + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshalizer, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + SwitchJailWaitingEnableEpochField: 0, + } + userAccountsDB := createAccountsDB(hasher, marshalizer, accCreator, trieFactoryManager, enableEpochsHandler) shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) shardCoordinator.CurrentShard = core.MetachainShardId shardCoordinator.ComputeIdCalled = func(address []byte) uint32 { @@ -1165,10 +1181,8 @@ func getBaseRewardsArguments() BaseRewardsCreatorArgs { return 63 }, }, - UserAccountsDB: userAccountsDB, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ - SwitchJailWaitingEnableEpochField: 0, - }, + UserAccountsDB: userAccountsDB, + EnableEpochsHandler: enableEpochsHandler, } } diff --git a/epochStart/metachain/epochStartData_test.go b/epochStart/metachain/epochStartData_test.go index 66a99ed2b9b..030bfb93a8c 100644 --- a/epochStart/metachain/epochStartData_test.go +++ b/epochStart/metachain/epochStartData_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -75,7 +76,7 @@ func createMockEpochStartCreatorArguments() ArgsNewEpochStartData { ShardCoordinator: shardCoordinator, EpochStartTrigger: &mock.EpochStartTriggerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return argsNewEpochStartData } @@ -706,7 +707,7 @@ func Test_setIndexOfFirstAndLastTxProcessedShouldNotSetReserved(t *testing.T) { partialExecutionEnableEpoch := uint32(5) arguments := createMockEpochStartCreatorArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ MiniBlockPartialExecutionEnableEpochField: partialExecutionEnableEpoch, } arguments.EpochStartTrigger = &mock.EpochStartTriggerStub{ @@ -732,7 +733,7 @@ func Test_setIndexOfFirstAndLastTxProcessedShouldSetReserved(t *testing.T) { partialExecutionEnableEpoch := uint32(5) arguments := createMockEpochStartCreatorArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ MiniBlockPartialExecutionEnableEpochField: partialExecutionEnableEpoch, } arguments.EpochStartTrigger = &mock.EpochStartTriggerStub{ diff --git a/epochStart/metachain/rewardsCreatorProxy_test.go b/epochStart/metachain/rewardsCreatorProxy_test.go index 78483cec9e7..5f160297e1f 100644 --- a/epochStart/metachain/rewardsCreatorProxy_test.go +++ b/epochStart/metachain/rewardsCreatorProxy_test.go @@ -14,8 +14,8 @@ import ( "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/stretchr/testify/require" ) @@ -101,7 +101,7 @@ func TestRewardsCreatorProxy_CreateRewardsMiniBlocksWithSwitchToRewardsCreatorV2 } rewardsCreatorProxy, vInfo, metaBlock := createTestData(rewardCreatorV1, rCreatorV1) - stub, _ := rewardsCreatorProxy.args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stub, _ := rewardsCreatorProxy.args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stub.StakingV2EnableEpochField = 1 metaBlock.Epoch = 3 economics := &metaBlock.EpochStart.Economics @@ -128,7 +128,7 @@ func TestRewardsCreatorProxy_CreateRewardsMiniBlocksWithSwitchToRewardsCreatorV1 } rewardsCreatorProxy, vInfo, metaBlock := createTestData(rewardCreatorV2, rCreatorV2) - stub, _ := rewardsCreatorProxy.args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stub, _ := rewardsCreatorProxy.args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stub.StakingV2EnableEpochField = 5 metaBlock.Epoch = 3 economics := &metaBlock.EpochStart.Economics diff --git a/epochStart/metachain/systemSCs.go b/epochStart/metachain/systemSCs.go index 645f54ce3ea..d0138f52c88 100644 --- a/epochStart/metachain/systemSCs.go +++ b/epochStart/metachain/systemSCs.go @@ -3,7 +3,6 @@ package metachain import ( "bytes" "context" - "encoding/hex" "fmt" "math" "math/big" @@ -24,7 +23,6 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -1096,27 +1094,18 @@ func (s *systemSCProcessor) getValidatorSystemAccount() (state.UserAccountHandle func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAccount state.UserAccountHandler) ([][]byte, error) { arguments := make([][]byte, 0) - rootHash, err := userValidatorAccount.DataTrie().RootHash() - if err != nil { - return nil, err - } - leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) + err := userValidatorAccount.GetAllLeaves(leavesChannels, context.Background()) if err != nil { return nil, err } for leaf := range leavesChannels.LeavesChan { validatorData := &systemSmartContracts.ValidatorDataV2{} - value, errTrim := leaf.ValueWithoutSuffix(append(leaf.Key(), vm.ValidatorSCAddress...)) - if errTrim != nil { - return nil, fmt.Errorf("%w for validator key %s", errTrim, hex.EncodeToString(leaf.Key())) - } - err = s.marshalizer.Unmarshal(validatorData, value) + err = s.marshalizer.Unmarshal(validatorData, leaf.Value()) if err != nil { continue } diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index 8b48dc948f7..59e41daf56b 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -44,6 +44,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" @@ -883,8 +884,9 @@ func createAccountsDB( marshaller marshal.Marshalizer, accountFactory state.AccountFactory, trieStorageManager common.StorageManager, + enableEpochsHandler common.EnableEpochsHandler, ) *state.AccountsDB { - tr, _ := trie.NewTrie(trieStorageManager, marshaller, hasher, 5) + tr, _ := trie.NewTrie(trieStorageManager, marshaller, hasher, enableEpochsHandler, 5) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, @@ -917,13 +919,21 @@ func createFullArgumentsForSystemSCProcessing(enableEpochsConfig config.EnableEp storageManagerArgs.CheckpointsStorer = trieStorer trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, stateMock.GetStorageManagerOptions()) - userAccountsDB := createAccountsDB(hasher, marshalizer, factory.NewAccountCreator(), trieFactoryManager) - peerAccountsDB := createAccountsDB(hasher, marshalizer, factory.NewPeerAccountCreator(), trieFactoryManager) + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshalizer, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) + peerAccCreator := factory.NewPeerAccountCreator() en := forking.NewGenericEpochNotifier() epochsConfig := &config.EpochConfig{ EnableEpochs: enableEpochsConfig, } enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(epochsConfig.EnableEpochs, en) + userAccountsDB := createAccountsDB(hasher, marshalizer, accCreator, trieFactoryManager, enableEpochsHandler) + peerAccountsDB := createAccountsDB(hasher, marshalizer, peerAccCreator, trieFactoryManager, enableEpochsHandler) + argsValidatorsProcessor := peer.ArgValidatorStatisticsProcessor{ Marshalizer: marshalizer, NodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, @@ -1113,7 +1123,7 @@ func createEconomicsData() process.EconomicsDataHandler { }, }, EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, BuiltInFunctionsCostHandler: &mock.BuiltInCostHandlerStub{}, TxVersionChecker: &testscommon.TxVersionCheckerStub{}, } @@ -1582,7 +1592,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractUnStakeFromDelegationContra assert.Equal(t, 4, len(validatorInfos[0])) delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr) - marshalledData, _, err := delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) + marshalledData, _, err := delegationSC.RetrieveValue([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1671,7 +1681,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractShouldUnStakeFromAdditional } delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr) - marshalledData, _, err := delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) + marshalledData, _, err := delegationSC.RetrieveValue([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1761,7 +1771,7 @@ func TestSystemSCProcessor_ProcessSystemSmartContractUnStakeFromAdditionalQueue( assert.Nil(t, err) delegationSC := loadSCAccount(args.UserAccountsDB, delegationAddr2) - marshalledData, _, err := delegationSC.DataTrie().(common.Trie).Get([]byte("delegationStatus")) + marshalledData, _, err := delegationSC.RetrieveValue([]byte("delegationStatus")) assert.Nil(t, err) dStatus := &systemSmartContracts.DelegationContractStatus{ StakedKeys: make([]*systemSmartContracts.NodesData, 0), @@ -1846,14 +1856,14 @@ func TestSystemSCProcessor_TogglePauseUnPause(t *testing.T) { assert.Nil(t, err) validatorSC := loadSCAccount(s.userAccountsDB, vm.ValidatorSCAddress) - value, _, _ := validatorSC.DataTrie().(common.Trie).Get([]byte("unStakeUnBondPause")) + value, _, _ := validatorSC.RetrieveValue([]byte("unStakeUnBondPause")) assert.True(t, value[0] == 1) err = s.ToggleUnStakeUnBond(false) assert.Nil(t, err) validatorSC = loadSCAccount(s.userAccountsDB, vm.ValidatorSCAddress) - value, _, _ = validatorSC.DataTrie().(common.Trie).Get([]byte("unStakeUnBondPause")) + value, _, _ = validatorSC.RetrieveValue([]byte("unStakeUnBondPause")) assert.True(t, value[0] == 0) } diff --git a/epochStart/metachain/validators_test.go b/epochStart/metachain/validators_test.go index d64093ee929..5fc1192d5d9 100644 --- a/epochStart/metachain/validators_test.go +++ b/epochStart/metachain/validators_test.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" vics "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" "github.com/stretchr/testify/assert" @@ -131,7 +132,7 @@ func createMockEpochValidatorInfoCreatorsArguments() ArgsNewValidatorInfoCreator return &vics.ValidatorInfoCacherStub{} }, }, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, }, } @@ -569,7 +570,7 @@ func TestEpochValidatorInfoCreator_GetShardValidatorInfoData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: false, } vic, _ := NewValidatorInfoCreator(arguments) @@ -586,7 +587,7 @@ func TestEpochValidatorInfoCreator_GetShardValidatorInfoData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } vic, _ := NewValidatorInfoCreator(arguments) @@ -607,7 +608,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: false, } vic, _ := NewValidatorInfoCreator(arguments) @@ -621,7 +622,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } vic, _ := NewValidatorInfoCreator(arguments) @@ -634,7 +635,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } vic, _ := NewValidatorInfoCreator(arguments) @@ -648,7 +649,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } vic, _ := NewValidatorInfoCreator(arguments) @@ -662,7 +663,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { t.Parallel() arguments := createMockEpochValidatorInfoCreatorsArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -695,7 +696,7 @@ func TestEpochValidatorInfoCreator_CreateMarshalledData(t *testing.T) { svi3 := &state.ShardValidatorInfo{PublicKey: []byte("z")} marshalledSVI3, _ := arguments.Marshalizer.Marshal(svi3) - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -739,7 +740,7 @@ func TestEpochValidatorInfoCreator_SetMarshalledValidatorInfoTxsShouldWork(t *te svi2 := &state.ShardValidatorInfo{PublicKey: []byte("y")} marshalledSVI2, _ := arguments.Marshalizer.Marshal(svi2) - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -776,7 +777,7 @@ func TestEpochValidatorInfoCreator_GetValidatorInfoTxsShouldWork(t *testing.T) { svi2 := &state.ShardValidatorInfo{PublicKey: []byte("y")} svi3 := &state.ShardValidatorInfo{PublicKey: []byte("z")} - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -818,7 +819,7 @@ func TestEpochValidatorInfoCreator_SetMapShardValidatorInfoShouldWork(t *testing svi1 := &state.ShardValidatorInfo{PublicKey: []byte("x")} svi2 := &state.ShardValidatorInfo{PublicKey: []byte("y")} - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -858,7 +859,7 @@ func TestEpochValidatorInfoCreator_GetShardValidatorInfoShouldWork(t *testing.T) svi := &state.ShardValidatorInfo{PublicKey: []byte("x")} marshalledSVI, _ := arguments.Marshalizer.Marshal(svi) - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: false, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -886,7 +887,7 @@ func TestEpochValidatorInfoCreator_GetShardValidatorInfoShouldWork(t *testing.T) svi := &state.ShardValidatorInfo{PublicKey: []byte("x")} - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index 5959faaa9a0..3013fac8c13 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -64,7 +65,7 @@ func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { PeerMiniBlocksSyncer: &mock.ValidatorInfoSyncerStub{}, RoundHandler: &mock.RoundHandlerStub{}, AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/errors/errors.go b/errors/errors.go index 131f93f2b72..aa88cd55d99 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -552,3 +552,9 @@ var ErrNilHistoryRepository = errors.New("history repository is nil") // ErrNilMissingTrieNodesNotifier signals that a nil missing trie nodes notifier was provided var ErrNilMissingTrieNodesNotifier = errors.New("nil missing trie nodes notifier") + +// ErrInvalidTrieNodeVersion signals that an invalid trie node version has been provided +var ErrInvalidTrieNodeVersion = errors.New("invalid trie node version") + +// ErrNilTrieMigrator signals that a nil trie migrator has been provided +var ErrNilTrieMigrator = errors.New("nil trie migrator") diff --git a/facade/initial/initialNodeFacade.go b/facade/initial/initialNodeFacade.go index cd268eacf0a..05ebfd68ba9 100644 --- a/facade/initial/initialNodeFacade.go +++ b/facade/initial/initialNodeFacade.go @@ -400,6 +400,11 @@ func (inf *initialNodeFacade) GetGasConfigs() (map[string]map[string]uint64, err return nil, errNodeStarting } +// IsDataTrieMigrated returns false and error +func (inf *initialNodeFacade) IsDataTrieMigrated(_ string, _ api.AccountQueryOptions) (bool, error) { + return false, errNodeStarting +} + // IsInterfaceNil returns true if there is no value under the interface func (inf *initialNodeFacade) IsInterfaceNil() bool { return inf == nil diff --git a/facade/initial/initialNodeFacade_test.go b/facade/initial/initialNodeFacade_test.go index bba4b57eaa7..70ebf524359 100644 --- a/facade/initial/initialNodeFacade_test.go +++ b/facade/initial/initialNodeFacade_test.go @@ -232,6 +232,10 @@ func TestInitialNodeFacade_AllMethodsShouldNotPanic(t *testing.T) { assert.Equal(t, api.GuardianData{}, guardianData) assert.Equal(t, errNodeStarting, err) + isMigrated, err := inf.IsDataTrieMigrated("", api.AccountQueryOptions{}) + assert.False(t, isMigrated) + assert.Equal(t, errNodeStarting, err) + mainTrieResponse, dataTrieResponse, err := inf.GetProofDataTrie("", "", "") assert.Nil(t, mainTrieResponse) assert.Nil(t, dataTrieResponse) diff --git a/facade/interface.go b/facade/interface.go index f965c946ac3..d3d7f883fc8 100644 --- a/facade/interface.go +++ b/facade/interface.go @@ -100,6 +100,7 @@ type NodeHandler interface { GetProof(rootHash string, key string) (*common.GetProofResponse, error) GetProofDataTrie(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) VerifyProof(rootHash string, address string, proof [][]byte) (bool, error) + IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) } // TransactionSimulatorProcessor defines the actions which a transaction simulator processor has to implement diff --git a/facade/mock/nodeStub.go b/facade/mock/nodeStub.go index 4c69d0e2790..f4b4d643ebf 100644 --- a/facade/mock/nodeStub.go +++ b/facade/mock/nodeStub.go @@ -53,6 +53,7 @@ type NodeStub struct { GetProofDataTrieCalled func(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) VerifyProofCalled func(rootHash string, address string, proof [][]byte) (bool, error) GetTokenSupplyCalled func(token string) (*api.ESDTSupply, error) + IsDataTrieMigratedCalled func(address string, options api.AccountQueryOptions) (bool, error) } // GetProof - @@ -284,6 +285,14 @@ func (ns *NodeStub) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([] return make([]string, 0), nil } +// IsDataTrieMigrated - +func (ns *NodeStub) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) { + if ns.IsDataTrieMigratedCalled != nil { + return ns.IsDataTrieMigratedCalled(address, options) + } + return false, nil +} + // GetNFTTokenIDsRegisteredByAddress - func (ns *NodeStub) GetNFTTokenIDsRegisteredByAddress(address string, options api.AccountQueryOptions, ctx context.Context) ([]string, api.BlockInfo, error) { if ns.GetNFTTokenIDsRegisteredByAddressCalled != nil { diff --git a/facade/nodeFacade.go b/facade/nodeFacade.go index 649e5f9f3bf..c577b882e7d 100644 --- a/facade/nodeFacade.go +++ b/facade/nodeFacade.go @@ -585,6 +585,11 @@ func (nf *nodeFacade) VerifyProof(rootHash string, address string, proof [][]byt return nf.node.VerifyProof(rootHash, address, proof) } +// IsDataTrieMigrated returns true if the data trie for the given address is migrated +func (nf *nodeFacade) IsDataTrieMigrated(address string, options apiData.AccountQueryOptions) (bool, error) { + return nf.node.IsDataTrieMigrated(address, options) +} + func (nf *nodeFacade) convertVmOutputToApiResponse(input *vmcommon.VMOutput) *vm.VMOutputApi { outputAccounts := make(map[string]*vm.OutputAccountApi) for key, acc := range input.OutputAccounts { diff --git a/facade/nodeFacade_test.go b/facade/nodeFacade_test.go index 1f68c7c5108..71505fa6ddc 100644 --- a/facade/nodeFacade_test.go +++ b/facade/nodeFacade_test.go @@ -31,6 +31,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1188,6 +1189,59 @@ func TestNodeFacade_VerifyProof(t *testing.T) { require.True(t, response) } +func TestNodeFacade_IsDataTrieMigrated(t *testing.T) { + t.Parallel() + + t.Run("should return false if trie is not migrated", func(t *testing.T) { + t.Parallel() + + arg := createMockArguments() + arg.Node = &mock.NodeStub{ + IsDataTrieMigratedCalled: func(_ string, _ api.AccountQueryOptions) (bool, error) { + return false, nil + }, + } + nf, _ := NewNodeFacade(arg) + + isMigrated, err := nf.IsDataTrieMigrated("address", api.AccountQueryOptions{}) + assert.Nil(t, err) + assert.False(t, isMigrated) + }) + + t.Run("should return true if trie is migrated", func(t *testing.T) { + t.Parallel() + + arg := createMockArguments() + arg.Node = &mock.NodeStub{ + IsDataTrieMigratedCalled: func(_ string, _ api.AccountQueryOptions) (bool, error) { + return true, nil + }, + } + nf, _ := NewNodeFacade(arg) + + isMigrated, err := nf.IsDataTrieMigrated("address", api.AccountQueryOptions{}) + assert.Nil(t, err) + assert.True(t, isMigrated) + }) + + t.Run("should return error if node returns err", func(t *testing.T) { + t.Parallel() + + expectedErr := fmt.Errorf(" expected error") + arg := createMockArguments() + arg.Node = &mock.NodeStub{ + IsDataTrieMigratedCalled: func(_ string, _ api.AccountQueryOptions) (bool, error) { + return false, expectedErr + }, + } + nf, _ := NewNodeFacade(arg) + + isMigrated, err := nf.IsDataTrieMigrated("address", api.AccountQueryOptions{}) + assert.Equal(t, expectedErr, err) + assert.False(t, isMigrated) + }) +} + func TestNodeFacade_ExecuteSCQuery(t *testing.T) { t.Parallel() diff --git a/factory/api/apiResolverFactory_test.go b/factory/api/apiResolverFactory_test.go index 3a8288cdf1d..afb98acecb4 100644 --- a/factory/api/apiResolverFactory_test.go +++ b/factory/api/apiResolverFactory_test.go @@ -22,10 +22,12 @@ import ( componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMocks "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/stretchr/testify/require" ) @@ -296,9 +298,9 @@ func createMockSCQueryElementArgs() api.SCQueryElementArgs { return []byte(humanReadable), nil }, }, - IntMarsh: &testscommon.MarshalizerStub{}, + IntMarsh: &marshallerMock.MarshalizerStub{}, EpochChangeNotifier: &epochNotifierMock.EpochNotifierStub{}, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, UInt64ByteSliceConv: &testsMocks.Uint64ByteSliceConverterMock{}, EconomicsHandler: &economicsmocks.EconomicsHandlerStub{}, NodesConfig: &testscommon.NodesSetupStub{}, diff --git a/factory/bootstrap/shardingFactory_test.go b/factory/bootstrap/shardingFactory_test.go index 277589f2e7e..0df777933b0 100644 --- a/factory/bootstrap/shardingFactory_test.go +++ b/factory/bootstrap/shardingFactory_test.go @@ -16,6 +16,8 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -194,7 +196,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, &mock.EpochStartNotifierStub{}, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -204,7 +206,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, errErd.ErrNilShuffleOutCloser, err) @@ -219,7 +221,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, &mock.EpochStartNotifierStub{}, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -229,7 +231,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, errErd.ErrNilGenesisNodesSetupHandler, err) @@ -244,7 +246,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, nil, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -254,7 +256,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, errErd.ErrNilEpochStartNotifier, err) @@ -269,7 +271,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, &mock.EpochStartNotifierStub{}, nil, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -279,7 +281,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, errErd.ErrNilPublicKey, err) @@ -294,7 +296,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, &mock.EpochStartNotifierStub{}, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -304,7 +306,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, errErd.ErrNilBootstrapParamsHandler, err) @@ -319,7 +321,7 @@ func TestCreateNodesCoordinator(t *testing.T) { config.PreferencesConfig{}, &mock.EpochStartNotifierStub{}, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -329,7 +331,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, nil, &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Equal(t, nodesCoordinator.ErrNilNodeStopChannel, err) @@ -346,7 +348,7 @@ func TestCreateNodesCoordinator(t *testing.T) { }, &mock.EpochStartNotifierStub{}, &cryptoMocks.PublicKeyStub{}, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -356,7 +358,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.NotNil(t, err) @@ -377,7 +379,7 @@ func TestCreateNodesCoordinator(t *testing.T) { return nil, expectedErr }, }, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -387,7 +389,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.True(t, errors.Is(err, expectedErr)) @@ -408,7 +410,7 @@ func TestCreateNodesCoordinator(t *testing.T) { return nil, expectedErr }, }, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -418,7 +420,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.True(t, errors.Is(err, expectedErr)) @@ -439,7 +441,7 @@ func TestCreateNodesCoordinator(t *testing.T) { return nil, nil // no error but nil pub key to force NewShuffledOutTrigger to fail }, }, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -449,7 +451,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.NotNil(t, err) @@ -480,7 +482,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.NotNil(t, err) @@ -501,7 +503,7 @@ func TestCreateNodesCoordinator(t *testing.T) { return []byte("public key"), nil }, }, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, nil, // force NewIndexHashedNodesCoordinatorWithRater to fail &storage.StorerStub{}, @@ -532,7 +534,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.NotNil(t, err) @@ -553,7 +555,7 @@ func TestCreateNodesCoordinator(t *testing.T) { return []byte("public key"), nil }, }, - &testscommon.MarshalizerStub{}, + &marshallerMock.MarshalizerStub{}, &testscommon.HasherStub{}, &testscommon.RaterMock{}, &storage.StorerStub{}, @@ -584,7 +586,7 @@ func TestCreateNodesCoordinator(t *testing.T) { 0, make(chan endProcess.ArgEndProcess, 1), &nodeTypeProviderMock.NodeTypeProviderStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, ) require.Nil(t, err) diff --git a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index 0cabc983f05..9681bdde7d0 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -510,6 +510,7 @@ func (ccf *consensusComponentsFactory) createArgsBaseAccountsSyncer(trieStorageM CheckNodesOnDisk: ccf.config.TrieSync.CheckNodesOnDisk, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), AppStatusHandler: disabled.NewAppStatusHandler(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } } diff --git a/factory/consensus/consensusComponents_test.go b/factory/consensus/consensusComponents_test.go index 184cb8d3d11..54f6a4cf4de 100644 --- a/factory/consensus/consensusComponents_test.go +++ b/factory/consensus/consensusComponents_test.go @@ -25,9 +25,11 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" factoryMocks "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" outportMocks "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -43,7 +45,7 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent Config: testscommon.GetGeneralConfig(), BootstrapRoundIndex: 0, CoreComponents: &mock.CoreComponentsMock{ - IntMarsh: &testscommon.MarshalizerStub{}, + IntMarsh: &marshallerMock.MarshalizerStub{}, Hash: &testscommon.HasherStub{ SizeCalled: func() int { return 1 @@ -63,8 +65,9 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent return 2 }, }, - EpochChangeNotifier: &epochNotifier.EpochNotifierStub{}, - StartTime: time.Time{}, + EpochChangeNotifier: &epochNotifier.EpochNotifierStub{}, + StartTime: time.Time{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, NetworkComponents: &testsMocks.NetworkComponentsStub{ Messenger: &p2pmocks.MessengerStub{}, diff --git a/factory/core/coreComponentsHandler_test.go b/factory/core/coreComponentsHandler_test.go index 9c22a9a2f22..a52568821da 100644 --- a/factory/core/coreComponentsHandler_test.go +++ b/factory/core/coreComponentsHandler_test.go @@ -8,8 +8,8 @@ import ( errorsMx "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/factory" coreComp "github.com/multiversx/mx-chain-go/factory/core" - "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/require" ) @@ -117,7 +117,7 @@ func TestManagedCoreComponents(t *testing.T) { require.NotNil(t, managedCoreComponents.ChanStopNodeProcess()) require.NotNil(t, managedCoreComponents.NodeTypeProvider()) require.NotNil(t, managedCoreComponents.EnableEpochsHandler()) - require.Nil(t, managedCoreComponents.SetInternalMarshalizer(&testscommon.MarshalizerStub{})) + require.Nil(t, managedCoreComponents.SetInternalMarshalizer(&marshallerMock.MarshalizerStub{})) require.Equal(t, factory.CoreComponentsName, managedCoreComponents.String()) }) diff --git a/factory/heartbeat/heartbeatV2Components_test.go b/factory/heartbeat/heartbeatV2Components_test.go index 06ff8958b40..46587997ecf 100644 --- a/factory/heartbeat/heartbeatV2Components_test.go +++ b/factory/heartbeat/heartbeatV2Components_test.go @@ -21,6 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/mainFactoryMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" @@ -43,7 +44,7 @@ func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2Co }, CoreComponents: &factory.CoreComponentsHolderStub{ InternalMarshalizerCalled: func() marshal.Marshalizer { - return &testscommon.MarshalizerStub{} + return &marshallerMock.MarshalizerStub{} }, HardforkTriggerPubKeyCalled: func() []byte { return []byte("hardfork pub key") diff --git a/factory/mock/accountFactoryStub.go b/factory/mock/accountFactoryStub.go deleted file mode 100644 index 2912f6fedd1..00000000000 --- a/factory/mock/accountFactoryStub.go +++ /dev/null @@ -1,18 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-vm-common-go" - -// AccountsFactoryStub - -type AccountsFactoryStub struct { - CreateAccountCalled func(address []byte) (vmcommon.AccountHandler, error) -} - -// CreateAccount - -func (afs *AccountsFactoryStub) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return afs.CreateAccountCalled(address) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (afs *AccountsFactoryStub) IsInterfaceNil() bool { - return afs == nil -} diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index 0982521c963..1d874717a65 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -97,11 +97,19 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) { trieStorageManagers[dataRetriever.UserAccountsUnit.String()] = storageManagerUser trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] = storageManagerPeer + argsAccCreator := state.ArgsAccountCreation{ + Hasher: coreComponents.Hasher(), + Marshaller: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), + } + accCreator, _ := factoryState.NewAccountCreator(argsAccCreator) + accounts, err := createAccountAdapter( &mock.MarshalizerMock{}, &hashingMocks.HasherMock{}, - factoryState.NewAccountCreator(), + accCreator, trieStorageManagers[dataRetriever.UserAccountsUnit.String()], + coreComponents.EnableEpochsHandler(), ) require.Nil(t, err) @@ -182,8 +190,9 @@ func createAccountAdapter( hasher hashing.Hasher, accountFactory state.AccountFactory, trieStorage common.StorageManager, + handler common.EnableEpochsHandler, ) (state.AccountsAdapter, error) { - tr, err := trie.NewTrie(trieStorage, marshaller, hasher, 5) + tr, err := trie.NewTrie(trieStorage, marshaller, hasher, handler, 5) if err != nil { return nil, err } diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index 916712488f6..8550fd67542 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -65,6 +65,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/networksharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/cache" storageFactory "github.com/multiversx/mx-chain-go/storage/factory" @@ -884,7 +885,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = pcf.state.AccountsAdapter().GetAllLeaves(leavesChannels, context.Background(), rootHash) + err = pcf.state.AccountsAdapter().GetAllLeaves(leavesChannels, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) if err != nil { return map[string]*alteredAccount.AlteredAccount{}, err } @@ -928,7 +929,12 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string } func (pcf *processComponentsFactory) unmarshalUserAccount(address []byte, userAccountsBytes []byte) (state.UserAccountHandler, error) { - userAccount, err := state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: pcf.coreData.Hasher(), + Marshaller: pcf.coreData.InternalMarshalizer(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), + } + userAccount, err := state.NewUserAccount(address, argsAccCreation) if err != nil { return nil, err } @@ -1543,6 +1549,7 @@ func (pcf *processComponentsFactory) createStorageRequestersForMeta( ManualEpochStartNotifier: manualEpochStartNotifier, ChanGracefullyClose: pcf.coreData.ChanStopNodeProcess(), SnapshotsEnabled: pcf.flagsConfig.SnapshotsEnabled, + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } return storagerequesterscontainer.NewMetaRequestersContainerFactory(requestersContainerFactoryArgs) @@ -1572,6 +1579,7 @@ func (pcf *processComponentsFactory) createStorageRequestersForShard( ManualEpochStartNotifier: manualEpochStartNotifier, ChanGracefullyClose: pcf.coreData.ChanStopNodeProcess(), SnapshotsEnabled: pcf.flagsConfig.SnapshotsEnabled, + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } return storagerequesterscontainer.NewShardRequestersContainerFactory(requestersContainerFactoryArgs) diff --git a/factory/processing/processComponents_test.go b/factory/processing/processComponents_test.go index 5287ceb46ff..4957aa60b50 100644 --- a/factory/processing/processComponents_test.go +++ b/factory/processing/processComponents_test.go @@ -38,11 +38,13 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" factoryMocks "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" "github.com/multiversx/mx-chain-go/testscommon/mainFactoryMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" @@ -183,7 +185,7 @@ func createMockProcessComponentsFactoryArgs() processComp.ProcessComponentsFacto Hash: blake2b.NewBlake2b(), TxVersionCheckHandler: &testscommon.TxVersionCheckerStub{}, RatingHandler: &testscommon.RaterMock{}, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, EpochNotifierWithConfirm: &updateMocks.EpochStartNotifierStub{}, RoundHandlerField: &testscommon.RoundHandlerMock{}, @@ -405,7 +407,7 @@ func TestNewProcessComponentsFactory(t *testing.T) { AddrPubKeyConv: &testscommon.PubkeyConverterStub{}, EpochChangeNotifier: &epochNotifier.EpochNotifierStub{}, ValPubKeyConv: &testscommon.PubkeyConverterStub{}, - IntMarsh: &testscommon.MarshalizerStub{}, + IntMarsh: &marshallerMock.MarshalizerStub{}, UInt64ByteSliceConv: nil, } pcf, err := processComp.NewProcessComponentsFactory(args) @@ -788,7 +790,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { stateCompMock := factoryMocks.NewStateComponentsMockFromRealComponent(args.State) realAccounts := stateCompMock.AccountsAdapter() stateCompMock.Accounts = &state.AccountsStub{ - GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeavesParser common.TrieLeafParser) error { close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() return expectedErr @@ -823,7 +825,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { stateCompMock := factoryMocks.NewStateComponentsMockFromRealComponent(args.State) realAccounts := stateCompMock.AccountsAdapter() stateCompMock.Accounts = &state.AccountsStub{ - GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeavesParser common.TrieLeafParser) error { addrOk, _ := addrPubKeyConv.Decode("erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4") addrNOK, _ := addrPubKeyConv.Decode("erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp") leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(addrOk, []byte("value")) // coverage @@ -840,7 +842,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { coreCompStub := factoryMocks.NewCoreComponentsHolderStubFromRealComponent(args.CoreData) cnt := 0 coreCompStub.InternalMarshalizerCalled = func() marshal.Marshalizer { - return &testscommon.MarshalizerStub{ + return &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { cnt++ if cnt == 1 { @@ -876,7 +878,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { realStateComp := args.State args.State = &factoryMocks.StateComponentsMock{ Accounts: &state.AccountsStub{ - GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeavesParser common.TrieLeafParser) error { close(leavesChannels.LeavesChan) leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) leavesChannels.ErrChan.Close() @@ -917,7 +919,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { realStateComp := args.State args.State = &factoryMocks.StateComponentsMock{ Accounts: &state.AccountsStub{ - GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeavesParser common.TrieLeafParser) error { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("invalid addr"), []byte("value")) close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -934,7 +936,7 @@ func TestProcessComponentsFactory_Create(t *testing.T) { } coreCompStub := factoryMocks.NewCoreComponentsHolderStubFromRealComponent(args.CoreData) coreCompStub.InternalMarshalizerCalled = func() marshal.Marshalizer { - return &testscommon.MarshalizerStub{ + return &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return nil }, diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 1778f0e103c..331b83d8fd2 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -112,7 +112,16 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { } func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common.TriesHolder) (state.AccountsAdapter, state.AccountsAdapter, state.AccountsRepository, error) { - accountFactory := factoryState.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: scf.core.Hasher(), + Marshaller: scf.core.InternalMarshalizer(), + EnableEpochsHandler: scf.core.EnableEpochsHandler(), + } + accountFactory, err := factoryState.NewAccountCreator(argsAccCreator) + if err != nil { + return nil, nil, nil, err + } + merkleTrie := triesContainer.Get([]byte(dataRetriever.UserAccountsUnit.String())) storagePruning, err := scf.newStoragePruningManager() if err != nil { diff --git a/genesis/mock/userAccountMock.go b/genesis/mock/userAccountMock.go index f2ae6ecf136..a6a8f873434 100644 --- a/genesis/mock/userAccountMock.go +++ b/genesis/mock/userAccountMock.go @@ -1,9 +1,11 @@ package mock import ( + "context" "errors" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" ) @@ -140,7 +142,7 @@ func (uam *UserAccountMock) GetUserName() []byte { } // SaveDirtyData - -func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) (map[string][]byte, error) { +func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { return nil, nil } @@ -148,3 +150,8 @@ func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) (map[string][]byte, err func (uam *UserAccountMock) IsGuarded() bool { return false } + +// GetAllLeaves - +func (uam *UserAccountMock) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context) error { + return nil +} diff --git a/genesis/process/genesisBlockCreator.go b/genesis/process/genesisBlockCreator.go index 2fde795be1f..b396cdcd94e 100644 --- a/genesis/process/genesisBlockCreator.go +++ b/genesis/process/genesisBlockCreator.go @@ -23,6 +23,7 @@ import ( "github.com/multiversx/mx-chain-go/process/smartContract/hooks" "github.com/multiversx/mx-chain-go/process/smartContract/hooks/counters" "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/state" factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/statusHandler" @@ -116,6 +117,7 @@ func (gbc *genesisBlockCreator) createHardForkImportHandler() error { StorageConfig: gbc.arg.HardForkConfig.ImportStateStorageConfig, TrieStorageManagers: gbc.arg.TrieStorageManagers, AddressConverter: gbc.arg.Core.AddressPubKeyConverter(), + EnableEpochsHandler: gbc.arg.Core.EnableEpochsHandler(), } importHandler, err := hardfork.NewStateImport(argsHardForkImport) if err != nil { @@ -491,14 +493,25 @@ func (gbc *genesisBlockCreator) getNewArgForShard(shardID uint32) (ArgsGenesisBl newArgument.Data = newArgument.Data.Clone().(dataComponentsHandler) return newArgument, nil } - newArgument := gbc.arg // copy the arguments + + argsAccCreator := state.ArgsAccountCreation{ + Hasher: newArgument.Core.Hasher(), + Marshaller: newArgument.Core.InternalMarshalizer(), + EnableEpochsHandler: newArgument.Core.EnableEpochsHandler(), + } + accCreator, err := factoryState.NewAccountCreator(argsAccCreator) + if err != nil { + return ArgsGenesisBlockCreator{}, err + } + newArgument.Accounts, err = createAccountAdapter( newArgument.Core.InternalMarshalizer(), newArgument.Core.Hasher(), - factoryState.NewAccountCreator(), + accCreator, gbc.arg.TrieStorageManagers[dataRetriever.UserAccountsUnit.String()], gbc.arg.Core.AddressPubKeyConverter(), + newArgument.Core.EnableEpochsHandler(), ) if err != nil { return ArgsGenesisBlockCreator{}, fmt.Errorf("'%w' while generating an in-memory accounts adapter for shard %d", diff --git a/genesis/process/genesisBlockCreator_test.go b/genesis/process/genesisBlockCreator_test.go index db4e29072d8..ea7939f3d2a 100644 --- a/genesis/process/genesisBlockCreator_test.go +++ b/genesis/process/genesisBlockCreator_test.go @@ -28,6 +28,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -71,7 +72,7 @@ func createMockArgument( Chain: "chainID", TxVersionCheck: &testscommon.TxVersionCheckerStub{}, MinTxVersion: 1, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, Data: &mock.DataComponentsMock{ Storage: &storageCommon.ChainStorerStub{ @@ -150,13 +151,21 @@ func createMockArgument( SelfShardId: 0, } - var err error + argsAccCreator := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, err := factoryState.NewAccountCreator(argsAccCreator) + require.Nil(t, err) + arg.Accounts, err = createAccountAdapter( &mock.MarshalizerMock{}, &hashingMocks.HasherMock{}, - factoryState.NewAccountCreator(), + accCreator, trieStorageManagers[dataRetriever.UserAccountsUnit.String()], &testscommon.PubkeyConverterMock{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) require.Nil(t, err) diff --git a/genesis/process/memoryComponents.go b/genesis/process/memoryComponents.go index 799e9e75d6b..623c6f69f12 100644 --- a/genesis/process/memoryComponents.go +++ b/genesis/process/memoryComponents.go @@ -19,8 +19,9 @@ func createAccountAdapter( accountFactory state.AccountFactory, trieStorage common.StorageManager, addressConverter core.PubkeyConverter, + enableEpochsHandler common.EnableEpochsHandler, ) (state.AccountsAdapter, error) { - tr, err := trie.NewTrie(trieStorage, marshaller, hasher, maxTrieLevelInMemory) + tr, err := trie.NewTrie(trieStorage, marshaller, hasher, enableEpochsHandler, maxTrieLevelInMemory) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 7e6cebc7912..276de428a2b 100644 --- a/go.mod +++ b/go.mod @@ -14,15 +14,15 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/mitchellh/mapstructure v1.5.0 github.com/multiversx/mx-chain-communication-go v1.0.2 - github.com/multiversx/mx-chain-core-go v1.2.5 + github.com/multiversx/mx-chain-core-go v1.2.6 github.com/multiversx/mx-chain-crypto-go v1.2.6 github.com/multiversx/mx-chain-es-indexer-go v1.4.4 github.com/multiversx/mx-chain-logger-go v1.0.11 - github.com/multiversx/mx-chain-storage-go v1.0.10 - github.com/multiversx/mx-chain-vm-common-go v1.4.5 - github.com/multiversx/mx-chain-vm-v1_2-go v1.2.55 - github.com/multiversx/mx-chain-vm-v1_3-go v1.3.56 - github.com/multiversx/mx-chain-vm-v1_4-go v1.4.82 + github.com/multiversx/mx-chain-storage-go v1.0.11 + github.com/multiversx/mx-chain-vm-common-go v1.4.6 + github.com/multiversx/mx-chain-vm-v1_2-go v1.2.56 + github.com/multiversx/mx-chain-vm-v1_3-go v1.3.57 + github.com/multiversx/mx-chain-vm-v1_4-go v1.4.83 github.com/pelletier/go-toml v1.9.3 github.com/pkg/errors v0.9.1 github.com/shirou/gopsutil v3.21.11+incompatible diff --git a/go.sum b/go.sum index 82a1022faa0..7730c58d485 100644 --- a/go.sum +++ b/go.sum @@ -622,8 +622,9 @@ github.com/multiversx/mx-chain-core-go v1.1.30/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZ github.com/multiversx/mx-chain-core-go v1.2.1-0.20230510143029-ab37792342df/go.mod h1:jzYFSiYBuO0dGpGFXnZWSwcwcKP7Flyn/X41y4zIQrQ= github.com/multiversx/mx-chain-core-go v1.2.1/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= github.com/multiversx/mx-chain-core-go v1.2.4/go.mod h1:jzYFSiYBuO0dGpGFXnZWSwcwcKP7Flyn/X41y4zIQrQ= -github.com/multiversx/mx-chain-core-go v1.2.5 h1:uIZSqRygJAxv+pGuZnoSMwS4t10C/paasuwps5nxrIQ= github.com/multiversx/mx-chain-core-go v1.2.5/go.mod h1:jzYFSiYBuO0dGpGFXnZWSwcwcKP7Flyn/X41y4zIQrQ= +github.com/multiversx/mx-chain-core-go v1.2.6 h1:fD5cMsByM1kgvNI+uGCQGlhvr+TrV7FPvJlXT4ubYdg= +github.com/multiversx/mx-chain-core-go v1.2.6/go.mod h1:jzYFSiYBuO0dGpGFXnZWSwcwcKP7Flyn/X41y4zIQrQ= github.com/multiversx/mx-chain-crypto-go v1.2.6 h1:yxsjAQGh62los+iYmORMfh3w9qen0xbYlmwU0juNSeg= github.com/multiversx/mx-chain-crypto-go v1.2.6/go.mod h1:rOj0Rr19HTOYt9YTeym7RKxlHt91NXln3LVKjHKVmA0= github.com/multiversx/mx-chain-es-indexer-go v1.4.4 h1:3k8pB1AEILlNXL2ggSnP43uqVBQQg3hbx7351IcGbh0= @@ -631,18 +632,17 @@ github.com/multiversx/mx-chain-es-indexer-go v1.4.4/go.mod h1:IAFuU3LhjVfs3+Sf4T github.com/multiversx/mx-chain-logger-go v1.0.11 h1:DFsHa+sc5fKwhDR50I8uBM99RTDTEW68ESyr5ALRDwE= github.com/multiversx/mx-chain-logger-go v1.0.11/go.mod h1:1srDkP0DQucWQ+rYfaq0BX2qLnULsUdRPADpYUTM6dA= github.com/multiversx/mx-chain-storage-go v1.0.8/go.mod h1:lEkFYFe6taKYxqO1einNgT1esY3K9Pj6vPnoCwV9C3U= -github.com/multiversx/mx-chain-storage-go v1.0.10 h1:5rzPMME+CEJyoGGJ1tAb6ISnPmr68VFvGoKo0hF0WtU= -github.com/multiversx/mx-chain-storage-go v1.0.10/go.mod h1:VP9fwyFBmbmDzahUuu0IeGX/dKG3iBWjN6FSQ6YtVaI= +github.com/multiversx/mx-chain-storage-go v1.0.11 h1:u4ZsfIXEU3nJWRUxyAswhBn2pT6tJkKRwf9pra4CpzA= +github.com/multiversx/mx-chain-storage-go v1.0.11/go.mod h1:VP9fwyFBmbmDzahUuu0IeGX/dKG3iBWjN6FSQ6YtVaI= github.com/multiversx/mx-chain-vm-common-go v1.4.1/go.mod h1:K6yCdro8VohzYI6GwjGzTO+fJiPgO5coo2sgQb+zA24= -github.com/multiversx/mx-chain-vm-common-go v1.4.4/go.mod h1:+AjDwO/RJwQ75dzHJ/gBxmi5uTdICdhAo8bGNHTf7Yk= -github.com/multiversx/mx-chain-vm-common-go v1.4.5 h1:/pIMGSGqNJXbfAKOqigU2yapuBlosMCJiu6r+r+XcHE= -github.com/multiversx/mx-chain-vm-common-go v1.4.5/go.mod h1:+AjDwO/RJwQ75dzHJ/gBxmi5uTdICdhAo8bGNHTf7Yk= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.55 h1:jvBLu7JoitavahMDCkfOGYWjgXGBOe+3JJ0hNxj9AZM= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.55/go.mod h1:jCNgHGyj0JoLAsmijOSVK0G+yphccp9gIKsp/mRguF4= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.56 h1:VXveqaT/wdipfhIdUHXxFderY3+KxtFEbrDkF+zirr8= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.56/go.mod h1:guKkvnEDwGPaysZOVa+SaHEyiFDRJkFSVu0VE7jbk4k= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.82 h1:f0jL0jMPayN+/J/ZoK9sDRLggqvUp+/DJmu0dVTQNq8= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.82/go.mod h1:tKdkDQXDPFE5vAYOAJOq2iiTibi9KeiasNWsmA4nEmk= +github.com/multiversx/mx-chain-vm-common-go v1.4.6 h1:/a5cRc9i1auexetg2PBBJejBNWPUdn2qjN/AlKfHrLw= +github.com/multiversx/mx-chain-vm-common-go v1.4.6/go.mod h1:cnMvZN8+4oDkjloTZVExlf8ShkMGWbbDb5/D//wLT/k= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.56 h1:eGnP/x72cwhxJ3OJqHJt3TmAClvYp78IIQKYuKjEB8E= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.56/go.mod h1:Z4Vnc0jdaMvOC8Dx621aduLY/hXQAIR0HJ3VbDhkBb0= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.57 h1:1FwDLIOmudfBiB5lBH1RibBNHzrOMSUrEqgIRChEB2I= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.57/go.mod h1:uJ2fynux4EslnMLoQIuxBd0DUpfSrYjBkQFOU8HWKTY= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.83 h1:CNlKDlRBblmvz+0l0j5oqbC8qLjogKI51OiqtKvKlXY= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.83/go.mod h1:l/FL5GQ0tH4taTUecHnKgA5FNrYJaxXD3nOrP0Wan0Q= github.com/multiversx/mx-components-big-int v0.1.1 h1:695mYPKYOrmGEGgRH4/pZruDoe3CPP1LHrBxKfvj5l4= github.com/multiversx/mx-components-big-int v0.1.1/go.mod h1:0QrcFdfeLgJ/am10HGBeH0G0DNF+0Qx1E4DS/iozQls= github.com/multiversx/protobuf v1.3.2 h1:RaNkxvGTGbA0lMcnHAN24qE1G1i+Xs5yHA6MDvQ4mSM= diff --git a/heartbeat/monitor/monitor_test.go b/heartbeat/monitor/monitor_test.go index 65a5319d046..83ae428fbee 100644 --- a/heartbeat/monitor/monitor_test.go +++ b/heartbeat/monitor/monitor_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" ) @@ -22,7 +23,7 @@ func createMockHeartbeatV2MonitorArgs() ArgHeartbeatV2Monitor { return ArgHeartbeatV2Monitor{ Cache: testscommon.NewCacherMock(), PubKeyConverter: &testscommon.PubkeyConverterMock{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, MaxDurationPeerUnresponsive: time.Second * 3, HideInactiveValidatorInterval: time.Second * 5, ShardId: 0, @@ -43,7 +44,7 @@ func createHeartbeatMessage(active bool, publicKeyBytes []byte) *heartbeat.Heart Timestamp: messageTimestamp, } - marshaller := testscommon.MarshalizerMock{} + marshaller := marshallerMock.MarshalizerMock{} payloadBytes, _ := marshaller.Marshal(payload) return &heartbeat.HeartbeatV2{ Payload: payloadBytes, @@ -320,7 +321,7 @@ func TestHeartbeatV2Monitor_GetHeartbeats(t *testing.T) { Timestamp: time.Now().Unix() - 30 + int64(i), // the last message will be the latest, so it will be returned } - marshaller := testscommon.MarshalizerMock{} + marshaller := marshallerMock.MarshalizerMock{} payloadBytes, _ := marshaller.Marshal(payload) providedMessages[i].Payload = payloadBytes @@ -356,7 +357,7 @@ func TestHeartbeatV2Monitor_GetHeartbeats(t *testing.T) { Timestamp: time.Now().Unix() - 30 + int64(i), // the last message will be the latest, so it will be returned } - marshaller := testscommon.MarshalizerMock{} + marshaller := marshallerMock.MarshalizerMock{} payloadBytes, _ := marshaller.Marshal(payload) providedMessages[i].Payload = payloadBytes } diff --git a/heartbeat/sender/baseSender_test.go b/heartbeat/sender/baseSender_test.go index f2848327b2b..dc19139fe29 100644 --- a/heartbeat/sender/baseSender_test.go +++ b/heartbeat/sender/baseSender_test.go @@ -5,8 +5,8 @@ import ( "time" "github.com/multiversx/mx-chain-go/heartbeat/mock" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -14,7 +14,7 @@ import ( func createMockBaseArgs() argBaseSender { return argBaseSender{ messenger: &p2pmocks.MessengerStub{}, - marshaller: &testscommon.MarshalizerMock{}, + marshaller: &marshallerMock.MarshalizerMock{}, topic: "topic", timeBetweenSends: time.Second, timeBetweenSendsWhenError: time.Second, diff --git a/heartbeat/sender/bootstrapSender_test.go b/heartbeat/sender/bootstrapSender_test.go index 974e9929e93..1f9dd524940 100644 --- a/heartbeat/sender/bootstrapSender_test.go +++ b/heartbeat/sender/bootstrapSender_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -19,7 +20,7 @@ import ( func createMockBootstrapSenderArgs() ArgBootstrapSender { return ArgBootstrapSender{ Messenger: &p2pmocks.MessengerStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, HeartbeatTopic: "hb-topic", HeartbeatTimeBetweenSends: time.Second, HeartbeatTimeBetweenSendsWhenError: time.Second, diff --git a/heartbeat/sender/heartbeatSender_test.go b/heartbeat/sender/heartbeatSender_test.go index 40d8f41db30..e4fd2c4bc3f 100644 --- a/heartbeat/sender/heartbeatSender_test.go +++ b/heartbeat/sender/heartbeatSender_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -219,7 +220,7 @@ func TestHeartbeatSender_Execute(t *testing.T) { argsBase := createMockBaseArgs() argsBase.timeBetweenSendsWhenError = time.Second * 3 argsBase.timeBetweenSends = time.Second * 2 - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -269,7 +270,7 @@ func TestHeartbeatSender_execute(t *testing.T) { t.Parallel() argsBase := createMockBaseArgs() - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -287,7 +288,7 @@ func TestHeartbeatSender_execute(t *testing.T) { argsBase := createMockBaseArgs() numOfCalls := 0 - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { if numOfCalls < 1 { numOfCalls++ diff --git a/heartbeat/sender/multikeyHeartbeatSender_test.go b/heartbeat/sender/multikeyHeartbeatSender_test.go index 0d46c8facf2..fec7a216720 100644 --- a/heartbeat/sender/multikeyHeartbeatSender_test.go +++ b/heartbeat/sender/multikeyHeartbeatSender_test.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -236,7 +237,7 @@ func TestMultikeyHeartbeatSender_Execute(t *testing.T) { argsBase := createMockBaseArgs() argsBase.timeBetweenSendsWhenError = time.Second * 3 argsBase.timeBetweenSends = time.Second * 2 - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, diff --git a/heartbeat/sender/peerAuthenticationSender_test.go b/heartbeat/sender/peerAuthenticationSender_test.go index c1a073cdbcd..901ebf31d3e 100644 --- a/heartbeat/sender/peerAuthenticationSender_test.go +++ b/heartbeat/sender/peerAuthenticationSender_test.go @@ -22,6 +22,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/stretchr/testify/assert" @@ -259,7 +260,7 @@ func TestPeerAuthenticationSender_execute(t *testing.T) { assert.Fail(t, "should have not called Messenger.BroadcastCalled") }, } - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -303,7 +304,7 @@ func TestPeerAuthenticationSender_execute(t *testing.T) { assert.Fail(t, "should have not called Messenger.BroadcastCalled") }, } - argsBase.marshaller = &testscommon.MarshalizerStub{ + argsBase.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { numCalls++ if numCalls < 2 { diff --git a/heartbeat/sender/sender_test.go b/heartbeat/sender/sender_test.go index b6402821e77..bc9db68bad1 100644 --- a/heartbeat/sender/sender_test.go +++ b/heartbeat/sender/sender_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/stretchr/testify/assert" @@ -22,7 +23,7 @@ import ( func createMockSenderArgs() ArgSender { return ArgSender{ Messenger: &p2pmocks.MessengerStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, PeerAuthenticationTopic: "pa-topic", HeartbeatTopic: "hb-topic", PeerAuthenticationTimeBetweenSends: time.Second, diff --git a/integrationTests/benchmarks/loadFromTrie_test.go b/integrationTests/benchmarks/loadFromTrie_test.go index 711ddeba293..3e2833d7067 100644 --- a/integrationTests/benchmarks/loadFromTrie_test.go +++ b/integrationTests/benchmarks/loadFromTrie_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/hashesHolder/disabled" @@ -103,7 +104,7 @@ func generateTriesWithMaxDepth( ) []*keyForTrie { tries := make([]*keyForTrie, numTries) for i := 0; i < numTries; i++ { - tr, _ := trie.NewTrie(storage, marshaller, hasher, 2) + tr, _ := trie.NewTrie(storage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 2) key := insertKeysIntoTrie(t, tr, numTrieLevels, numChildrenPerBranch) rootHash, _ := tr.RootHash() diff --git a/integrationTests/interface.go b/integrationTests/interface.go index 7a6de790497..634e1eb447e 100644 --- a/integrationTests/interface.go +++ b/integrationTests/interface.go @@ -110,5 +110,6 @@ type Facade interface { GetLastPoolNonceForSender(sender string) (uint64, error) GetTransactionsPoolNonceGapsForSender(sender string) (*common.TransactionsPoolNonceGapsForSenderApiResponse, error) GetAlteredAccountsForBlock(options dataApi.GetAlteredAccountsForBlockOptions) ([]*alteredAccount.AlteredAccount, error) + IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool } diff --git a/integrationTests/longTests/storage/storage_test.go b/integrationTests/longTests/storage/storage_test.go index 56474f26978..4bd0e903729 100644 --- a/integrationTests/longTests/storage/storage_test.go +++ b/integrationTests/longTests/storage/storage_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" "github.com/stretchr/testify/assert" @@ -117,7 +118,7 @@ func TestWriteContinuouslyInTree(t *testing.T) { trieStorage, _ := trie.CreateTrieStorageManager(storageManagerArgs, options) maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, &marshal.JsonMarshalizer{}, blake2b.NewBlake2b(), maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, &marshal.JsonMarshalizer{}, blake2b.NewBlake2b(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) defer func() { _ = store.DestroyUnit() diff --git a/integrationTests/mock/accountFactoryStub.go b/integrationTests/mock/accountFactoryStub.go deleted file mode 100644 index 2912f6fedd1..00000000000 --- a/integrationTests/mock/accountFactoryStub.go +++ /dev/null @@ -1,18 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-vm-common-go" - -// AccountsFactoryStub - -type AccountsFactoryStub struct { - CreateAccountCalled func(address []byte) (vmcommon.AccountHandler, error) -} - -// CreateAccount - -func (afs *AccountsFactoryStub) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return afs.CreateAccountCalled(address) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (afs *AccountsFactoryStub) IsInterfaceNil() bool { - return afs == nil -} diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 69a5ccfbdcf..d911cc22906 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -21,6 +21,7 @@ import ( vmFactory "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" factoryTests "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" @@ -567,6 +568,7 @@ func createHardForkExporter( return string(node.ChainID) } coreComponents.HardforkTriggerPubKeyField = []byte("provided hardfork pub key") + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{} cryptoComponents := integrationTests.GetDefaultCryptoComponents() cryptoComponents.BlockSig = node.OwnAccount.BlockSingleSigner diff --git a/integrationTests/multiShard/smartContract/dns/dns_test.go b/integrationTests/multiShard/smartContract/dns/dns_test.go index 4fdf6eb52f5..8608ff60785 100644 --- a/integrationTests/multiShard/smartContract/dns/dns_test.go +++ b/integrationTests/multiShard/smartContract/dns/dns_test.go @@ -12,7 +12,6 @@ import ( "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/hashing/keccak" - "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/genesis" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/relayedTx" @@ -306,7 +305,7 @@ func checkUserNamesAreDeleted( dnsAcc, _ := acnt.(state.UserAccountHandler) keyFromTrie := "value_state" + string(keccak.NewKeccak().Compute(userName)) - value, _, err := dnsAcc.DataTrie().(common.Trie).Get([]byte(keyFromTrie)) + value, _, err := dnsAcc.RetrieveValue([]byte(keyFromTrie)) assert.Nil(t, err) assert.Nil(t, value) } diff --git a/integrationTests/nodesCoordinatorFactory.go b/integrationTests/nodesCoordinatorFactory.go index 543f7966595..1235fbd16b6 100644 --- a/integrationTests/nodesCoordinatorFactory.go +++ b/integrationTests/nodesCoordinatorFactory.go @@ -9,7 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" ) @@ -49,7 +49,7 @@ func (tpn *IndexHashedNodesCoordinatorFactory) CreateNodesCoordinator(arg ArgInd Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } nodeShuffler, _ := nodesCoordinator.NewHashValidatorsShuffler(nodeShufflerArgs) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ @@ -70,7 +70,7 @@ func (tpn *IndexHashedNodesCoordinatorFactory) CreateNodesCoordinator(arg ArgInd ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ RefactorPeersMiniBlocksEnableEpochField: UnreachableEpoch, }, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, @@ -103,7 +103,7 @@ func (ihncrf *IndexHashedNodesCoordinatorWithRaterFactory) CreateNodesCoordinato Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsWaitingListFixFlagEnabledField: true, IsBalanceWaitingListsFlagEnabledField: true, }, @@ -127,7 +127,7 @@ func (ihncrf *IndexHashedNodesCoordinatorWithRaterFactory) CreateNodesCoordinato ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsWaitingListFixFlagEnabledField: true, RefactorPeersMiniBlocksEnableEpochField: UnreachableEpoch, }, diff --git a/integrationTests/state/genesisState/genesisState_test.go b/integrationTests/state/genesisState/genesisState_test.go index c7c8761ace6..306980f2ce6 100644 --- a/integrationTests/state/genesisState/genesisState_test.go +++ b/integrationTests/state/genesisState/genesisState_test.go @@ -16,8 +16,8 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" ) @@ -44,7 +44,7 @@ func TestCreationOfTheGenesisState(t *testing.T) { PubkeyConverter: integrationTests.TestAddressPubkeyConverter, KeyGenerator: &mock.KeyGenMock{}, Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, } accountsParser, err := parsing.NewAccountsParser(args) diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index f653c917308..96eafe5efda 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -36,6 +36,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" @@ -263,7 +264,7 @@ func TestTrieDB_RecreateFromStorageShouldWork(t *testing.T) { trieStorage, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(5) - tr1, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, hasher, maxTrieLevelInMemory) + tr1, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) key := hasher.Compute("key") value := hasher.Compute("value") @@ -1046,14 +1047,19 @@ func createAccounts( args.MainStorer = store trieStorage, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) - + argsAccCreator := state.ArgsAccountCreation{ + Hasher: integrationTests.TestHasher, + Marshaller: integrationTests.TestMarshalizer, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) argsAccountsDB := state.ArgsAccountsDB{ Trie: tr, Hasher: integrationTests.TestHasher, Marshaller: integrationTests.TestMarshalizer, - AccountFactory: factory.NewAccountCreator(), + AccountFactory: accCreator, StoragePruningManager: spm, ProcessingMode: common.Normal, ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, @@ -2478,14 +2484,20 @@ func createAccountsDBTestSetup() *state.AccountsDB { args.GeneralConfig = generalCfg trieStorage, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) + argsAccCreator := state.ArgsAccountCreation{ + Hasher: integrationTests.TestHasher, + Marshaller: integrationTests.TestMarshalizer, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) argsAccountsDB := state.ArgsAccountsDB{ Trie: tr, Hasher: integrationTests.TestHasher, Marshaller: integrationTests.TestMarshalizer, - AccountFactory: factory.NewAccountCreator(), + AccountFactory: accCreator, StoragePruningManager: spm, ProcessingMode: common.Normal, ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, diff --git a/integrationTests/state/stateTrieClose/stateTrieClose_test.go b/integrationTests/state/stateTrieClose/stateTrieClose_test.go index 7b96f2b39b1..9d99a178484 100644 --- a/integrationTests/state/stateTrieClose/stateTrieClose_test.go +++ b/integrationTests/state/stateTrieClose/stateTrieClose_test.go @@ -11,6 +11,8 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/state/parsers" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/goroutines" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" @@ -21,7 +23,7 @@ import ( func TestPatriciaMerkleTrie_Close(t *testing.T) { numLeavesToAdd := 200 trieStorage, _ := integrationTests.CreateTrieStorageManager(integrationTests.CreateMemUnit()) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, 5) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) for i := 0; i < numLeavesToAdd; i++ { _ = tr.Update([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) @@ -36,7 +38,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + _ = tr.GetAllLeavesOnChannel( + leavesChannel1, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) time.Sleep(time.Second) // allow the go routine to start idx, _ := gc.Snapshot() diff := gc.DiffGoRoutines(idxInitial, idx) @@ -48,7 +56,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + _ = tr.GetAllLeavesOnChannel( + leavesChannel1, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 2) @@ -63,7 +77,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + _ = tr.GetAllLeavesOnChannel( + leavesChannel1, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.Equal(t, 3, len(diff), fmt.Sprintf("%v", diff)) @@ -78,7 +98,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - _ = tr.GetAllLeavesOnChannel(leavesChannel2, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + _ = tr.GetAllLeavesOnChannel( + leavesChannel2, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) time.Sleep(time.Second) // allow the go routine to start idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 6ef9c6e5d9a..98e88edc668 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" @@ -332,7 +333,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = accState.GetAllLeaves(leavesChannel, context.Background(), rootHash) + err = accState.GetAllLeaves(leavesChannel, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) for range leavesChannel.LeavesChan { } require.Nil(t, err) @@ -360,7 +361,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = nRequester.AccntState.GetAllLeaves(leavesChannel, context.Background(), rootHash) + err = nRequester.AccntState.GetAllLeaves(leavesChannel, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) assert.Nil(t, err) numLeaves := 0 for range leavesChannel.LeavesChan { @@ -562,7 +563,13 @@ func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + err := tr.GetAllLeavesOnChannel( + leavesChannel, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) require.Nil(t, err) numLeaves := 0 @@ -591,6 +598,7 @@ func getUserAccountSyncerArgs(node *integrationTests.TestProcessorNode, version TrieSyncerVersion: version, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), AppStatusHandler: integrationTests.TestAppStatusHandler, + EnableEpochsHandler: node.EnableEpochsHandler, }, ShardId: 0, Throttler: thr, diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go index c4b6f89c673..59620306e34 100644 --- a/integrationTests/testConsensusNode.go +++ b/integrationTests/testConsensusNode.go @@ -41,6 +41,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" testFactory "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -378,7 +379,7 @@ func (tcn *TestConsensusNode) initNodesCoordinator( ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsWaitingListFixFlagEnabledField: true, }, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, diff --git a/integrationTests/testHeartbeatNode.go b/integrationTests/testHeartbeatNode.go index c56468b73f3..d8457d853ab 100644 --- a/integrationTests/testHeartbeatNode.go +++ b/integrationTests/testHeartbeatNode.go @@ -46,6 +46,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -307,7 +308,7 @@ func CreateNodesWithTestHeartbeatNode( ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, } nodesCoordinatorInstance, err := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) @@ -353,7 +354,7 @@ func CreateNodesWithTestHeartbeatNode( ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, } nodesCoordinatorInstance, err := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index b1c860ff006..9e79dfa8693 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -57,6 +57,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" testStorage "github.com/multiversx/mx-chain-go/testscommon/state" @@ -436,14 +437,23 @@ func CreateAccountsDB( accountType Type, trieStorageManager common.StorageManager, ) (*state.AccountsDB, common.Trie) { - tr, _ := trie.NewTrie(trieStorageManager, TestMarshalizer, TestHasher, maxTrieLevelInMemory) + return CreateAccountsDBWithEnableEpochsHandler(accountType, trieStorageManager, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) +} + +// CreateAccountsDBWithEnableEpochsHandler creates a new AccountsDb with the given enableEpochsHandler +func CreateAccountsDBWithEnableEpochsHandler( + accountType Type, + trieStorageManager common.StorageManager, + enableEpochsHandler common.EnableEpochsHandler, +) (*state.AccountsDB, common.Trie) { + tr, _ := trie.NewTrie(trieStorageManager, TestMarshalizer, TestHasher, enableEpochsHandler, maxTrieLevelInMemory) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) - accountFactory := getAccountFactory(accountType) + accountFactory, _ := getAccountFactory(accountType, enableEpochsHandler) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) args := state.ArgsAccountsDB{ Trie: tr, @@ -461,14 +471,19 @@ func CreateAccountsDB( return adb, tr } -func getAccountFactory(accountType Type) state.AccountFactory { +func getAccountFactory(accountType Type, enableEpochsHandler common.EnableEpochsHandler) (state.AccountFactory, error) { switch accountType { case UserAccount: - return factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: TestHasher, + Marshaller: TestMarshalizer, + EnableEpochsHandler: enableEpochsHandler, + } + return factory.NewAccountCreator(argsAccCreator) case ValidatorAccount: - return factory.NewPeerAccountCreator() + return factory.NewPeerAccountCreator(), nil default: - return nil + return nil, fmt.Errorf("invalid account type provided") } } @@ -815,7 +830,7 @@ func CreateGenesisMetaBlock( newBlkc, _ := blockchain.NewMetaChain(&statusHandlerMock.AppStatusHandlerStub{}) trieStorage, _ := CreateTrieStorageManager(CreateMemUnit()) - newAccounts, _ := CreateAccountsDB(UserAccount, trieStorage) + newAccounts, _ := CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorage, coreComponents.EnableEpochsHandler()) argsMetaGenesis.ShardCoordinator = newShardCoordinator argsMetaGenesis.Accounts = newAccounts @@ -912,7 +927,12 @@ func GenerateAddressJournalAccountAccountsDB() ([]byte, state.UserAccountHandler adr := CreateRandomAddress() trieStorage, _ := CreateTrieStorageManager(CreateMemUnit()) adb, _ := CreateAccountsDB(UserAccount, trieStorage) - account, _ := state.NewUserAccount(adr) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: TestHasher, + Marshaller: TestMarshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + account, _ := state.NewUserAccount(adr, argsAccCreation) return adr, account, adb } @@ -999,7 +1019,7 @@ func CreateSimpleTxProcessor(accnts state.AccountsAdapter) process.TransactionPr BadTxForwarder: &mock.IntermediateTransactionHandlerMock{}, ArgsParser: smartContract.NewArgumentParser(), ScrForwarder: &mock.IntermediateTransactionHandlerMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, TxVersionChecker: &testscommon.TxVersionCheckerStub{}, GuardianChecker: &guardianMocks.GuardedAccountHandlerStub{}, } @@ -1017,7 +1037,7 @@ func CreateNewDefaultTrie() common.Trie { trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) return tr } diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index dbfa6bc4b72..4ee12ddf869 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -116,7 +116,6 @@ import ( storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/trigger" "github.com/multiversx/mx-chain-go/vm" @@ -572,11 +571,11 @@ func (tpn *TestProcessorNode) initAccountDBsWithPruningStorer() { trieStorageManager := CreateTrieStorageManagerWithPruningStorer(tpn.ShardCoordinator, tpn.EpochStartNotifier) tpn.TrieContainer = state.NewDataTriesHolder() var stateTrie common.Trie - tpn.AccntState, stateTrie = CreateAccountsDB(UserAccount, trieStorageManager) + tpn.AccntState, stateTrie = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.TrieContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), stateTrie) var peerTrie common.Trie - tpn.PeerState, peerTrie = CreateAccountsDB(ValidatorAccount, trieStorageManager) + tpn.PeerState, peerTrie = CreateAccountsDBWithEnableEpochsHandler(ValidatorAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.TrieContainer.Put([]byte(dataRetriever.PeerAccountsUnit.String()), peerTrie) tpn.TrieStorageManagers = make(map[string]common.StorageManager) @@ -588,11 +587,11 @@ func (tpn *TestProcessorNode) initAccountDBs(store storage.Storer) { trieStorageManager, _ := CreateTrieStorageManager(store) tpn.TrieContainer = state.NewDataTriesHolder() var stateTrie common.Trie - tpn.AccntState, stateTrie = CreateAccountsDB(UserAccount, trieStorageManager) + tpn.AccntState, stateTrie = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.TrieContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), stateTrie) var peerTrie common.Trie - tpn.PeerState, peerTrie = CreateAccountsDB(ValidatorAccount, trieStorageManager) + tpn.PeerState, peerTrie = CreateAccountsDBWithEnableEpochsHandler(ValidatorAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.TrieContainer.Put([]byte(dataRetriever.PeerAccountsUnit.String()), peerTrie) tpn.TrieStorageManagers = make(map[string]common.StorageManager) @@ -3289,12 +3288,11 @@ func GetTokenIdentifier(nodes []*TestProcessorNode, ticker []byte) []byte { acc, _ := n.AccntState.LoadAccount(vm.ESDTSCAddress) userAcc, _ := acc.(state.UserAccountHandler) - rootHash, _ := userAcc.DataTrie().RootHash() chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - _ = userAcc.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) + _ = userAcc.GetAllLeaves(chLeaves, context.Background()) for leaf := range chLeaves.LeavesChan { if !bytes.HasPrefix(leaf.Key(), ticker) { continue diff --git a/integrationTests/testProcessorNodeWithCoordinator.go b/integrationTests/testProcessorNodeWithCoordinator.go index 7ab761e960f..44fde10f931 100644 --- a/integrationTests/testProcessorNodeWithCoordinator.go +++ b/integrationTests/testProcessorNodeWithCoordinator.go @@ -13,7 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage/cache" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" ) @@ -73,7 +73,7 @@ func CreateProcessorNodesWithNodesCoordinator( ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, ChanStopNode: endProcess.GetDummyEndProcessChannel(), IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, } diff --git a/integrationTests/testProcessorNodeWithMultisigner.go b/integrationTests/testProcessorNodeWithMultisigner.go index 16011396d99..a4cac7eea1b 100644 --- a/integrationTests/testProcessorNodeWithMultisigner.go +++ b/integrationTests/testProcessorNodeWithMultisigner.go @@ -30,6 +30,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" @@ -398,7 +399,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } nodeShuffler, _ := nodesCoordinator.NewHashValidatorsShuffler(shufflerArgs) epochStartSubscriber := notifier.NewEpochStartSubscriptionHandler() @@ -429,7 +430,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, } nodesCoordinatorInstance, err := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) @@ -544,7 +545,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( ChanStopNode: endProcess.GetDummyEndProcessChannel(), NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, IsFullArchive: false, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, } nodesCoord, err := nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) diff --git a/integrationTests/testProcessorNodeWithTestWebServer.go b/integrationTests/testProcessorNodeWithTestWebServer.go index aa07a327720..97e6bdab161 100644 --- a/integrationTests/testProcessorNodeWithTestWebServer.go +++ b/integrationTests/testProcessorNodeWithTestWebServer.go @@ -25,6 +25,7 @@ import ( txSimData "github.com/multiversx/mx-chain-go/process/txsimulator/data" "github.com/multiversx/mx-chain-go/process/txstatus" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" @@ -244,7 +245,7 @@ func createFacadeComponents(tpn *TestProcessorNode) (nodeFacade.ApiResolver, nod AlteredAccountsProvider: &testscommon.AlteredAccountsProviderStub{}, AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } blockAPIHandler, err := blockAPI.CreateAPIBlockProcessor(argsBlockAPI) log.LogIfError(err) diff --git a/integrationTests/testSyncNode.go b/integrationTests/testSyncNode.go index 74e6595d2f8..0bec036c39a 100644 --- a/integrationTests/testSyncNode.go +++ b/integrationTests/testSyncNode.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/outport" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" @@ -48,7 +49,7 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { coreComponents.HasherField = TestHasher coreComponents.Uint64ByteSliceConverterField = TestUint64Converter coreComponents.EpochNotifierField = tpn.EpochNotifier - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ RefactorPeersMiniBlocksEnableEpochField: UnreachableEpoch, } diff --git a/integrationTests/vm/testInitializer.go b/integrationTests/vm/testInitializer.go index cc306a99b51..71c825aefa5 100644 --- a/integrationTests/vm/testInitializer.go +++ b/integrationTests/vm/testInitializer.go @@ -54,6 +54,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/integrationtests" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -504,7 +505,7 @@ func CreateOneSCExecutorMockVM(accnts state.AccountsAdapter) vmcommon.VMExecutio NilCompiledSCStore: true, ConfigSCStorage: *defaultStorageConfig(), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, GasSchedule: CreateMockGasScheduleNotifier(), Counter: &testscommon.BlockChainHookCounterStub{}, MissingTrieNodesNotifier: &testscommon.MissingTrieNodesNotifierStub{}, @@ -1139,12 +1140,12 @@ func CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas( gasScheduleNotifier core.GasScheduleNotifier, ) (*VMTestContext, error) { feeAccumulator, _ := postprocess.NewFeeAccumulator() - accounts := integrationtests.CreateAccountsDB(db) + epochNotifierInstance := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, epochNotifierInstance) + accounts := integrationtests.CreateAccountsDB(db, enableEpochsHandler) vmConfig := createDefaultVMConfig() wasmVMChangeLocker := &sync.RWMutex{} - epochNotifierInstance := forking.NewGenericEpochNotifier() - enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, epochNotifierInstance) chainHandler := &testscommon.ChainHandlerStub{} var err error diff --git a/integrationTests/vm/txsFee/asyncESDT_test.go b/integrationTests/vm/txsFee/asyncESDT_test.go index 0e3f0a41fc1..8ed863c9403 100644 --- a/integrationTests/vm/txsFee/asyncESDT_test.go +++ b/integrationTests/vm/txsFee/asyncESDT_test.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests/vm/txsFee/utils" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/require" ) @@ -538,7 +539,7 @@ func TestAsyncESDTCallForThirdContractShouldWork(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, 1), ErrChan: errChan.NewErrChanWrapper(), } - err = testContext.Accounts.GetAllLeaves(leaves, context.Background(), roothash) + err = testContext.Accounts.GetAllLeaves(leaves, context.Background(), roothash, parsers.NewMainTrieLeafParser()) require.Nil(t, err) for range leaves.LeavesChan { diff --git a/integrationTests/vm/txsFee/migrateDataTrie_test.go b/integrationTests/vm/txsFee/migrateDataTrie_test.go new file mode 100644 index 00000000000..8a920bdc171 --- /dev/null +++ b/integrationTests/vm/txsFee/migrateDataTrie_test.go @@ -0,0 +1,327 @@ +//go:build !race +// +build !race + +// TODO remove build condition above to allow -race -short, after Wasm VM fix + +package txsFee + +import ( + "fmt" + "math" + "math/big" + "strconv" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/smartContractResult" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/integrationTests/vm" + "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/state" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" +) + +type statsCollector interface { + GetTrieStats(address string, rootHash []byte) (common.TrieStatisticsHandler, error) +} + +type dataTrie interface { + UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error +} + +func TestMigrateDataTrieBuiltInFunc(t *testing.T) { + t.Parallel() + + enableEpochs := config.EnableEpochs{ + AutoBalanceDataTriesEnableEpoch: 0, + } + shardCoordinator, _ := sharding.NewMultiShardCoordinator(3, 1) + gasScheduleNotifier := vm.CreateMockGasScheduleNotifier() + trieLoadPerNode := uint64(20000) + trieStorePerNode := uint64(50000) + gasScheduleNotifier.GasSchedule[core.BuiltInCostString]["TrieLoadPerNode"] = trieLoadPerNode + gasScheduleNotifier.GasSchedule[core.BuiltInCostString]["TrieStorePerNode"] = trieStorePerNode + sndAddr := []byte("12345678901234567890123456789111") + gasPrice := uint64(10) + + t.Run("deterministic trie", func(t *testing.T) { + t.Parallel() + + testContext, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas(enableEpochs, shardCoordinator, integrationTests.CreateMemUnit(), gasScheduleNotifier) + require.Nil(t, err) + defer testContext.Close() + + senderBalance := big.NewInt(100000000) + sndNonce := uint64(0) + _, _ = vm.CreateAccount(testContext.Accounts, sndAddr, sndNonce, senderBalance) + + numDataTrieLeaves := 10 + keyGenerator := func(i int) []byte { + return []byte(strconv.Itoa(i)) + } + rootHash, _ := generateDataTrie(t, testContext, sndAddr, numDataTrieLeaves, keyGenerator) + + dtr := getAccountDataTrie(t, testContext, sndAddr) + stats, ok := dtr.(statsCollector) + require.True(t, ok) + + dts, err := stats.GetTrieStats("", rootHash) + require.Nil(t, err) + require.Equal(t, uint64(numDataTrieLeaves-1), dts.GetLeavesMigrationStats()[core.NotSpecified]) + require.Equal(t, uint64(1), dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled]) + + // migrate first 2 leaves, return when loading the third leaf (not enough gas for the migration) + // 5 loads + 2 stores = 200k gas + gasLimit := uint64(220000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 200000) + + // migrate 2 leaves, 4 loads + 2 stores = 180k gas + gasLimit = uint64(200000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 180000) + + // do not start the migration process, not enough gas for at least one migration + gasLimit = uint64(50000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.UserError) + testGasConsumed(t, testContext, gasLimit, 50000) + + // return after loading a branch node, not enough gas for the migration + gasLimit = uint64(70000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 60000) + + // migrate 2 leaves, 5 loads + 2 stores = 200k gas + gasLimit = uint64(200000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 200000) + + // migrate 2 leaves, 3 loads + 2 stores = 160k gas + gasLimit = uint64(200000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 160000) + + // migrate 1 leaf, 2 loads + 1 store = 90k gas + gasLimit = uint64(200000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 90000) + + // no leaf left to migrate, 1 load = 20k gas + gasLimit = uint64(200000) + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + testGasConsumed(t, testContext, gasLimit, 20000) + + err = dtr.Commit() + require.Nil(t, err) + + rootHash, err = dtr.RootHash() + require.Nil(t, err) + + dts, err = stats.GetTrieStats("", rootHash) + require.Nil(t, err) + require.Equal(t, uint64(0), dts.GetLeavesMigrationStats()[core.NotSpecified]) + require.Equal(t, uint64(numDataTrieLeaves), dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled]) + }) + + t.Run("random trie - all leaves are migrated in multiple transactions", func(t *testing.T) { + t.Parallel() + + testContext, err := vm.CreatePreparedTxProcessorWithVMsWithShardCoordinatorDBAndGas(enableEpochs, shardCoordinator, integrationTests.CreateMemUnit(), gasScheduleNotifier) + require.Nil(t, err) + defer testContext.Close() + + sndNonce := uint64(0) + senderBalance := big.NewInt(math.MaxInt64) + _, _ = vm.CreateAccount(testContext.Accounts, sndAddr, sndNonce, senderBalance) + + numLeaves := 10000 + keyGenerator := func(i int) []byte { + return integrationTests.GenerateRandomSlice(32) + } + rootHash, allKeys := generateDataTrie(t, testContext, sndAddr, numLeaves, keyGenerator) + nonMigratedKeys := allKeys[1:] + + acc := getAccount(t, testContext, sndAddr) + dtr, ok := acc.DataTrie().(common.Trie) + require.True(t, ok) + stats, ok := dtr.(statsCollector) + require.True(t, ok) + + dts, err := stats.GetTrieStats("", rootHash) + require.Nil(t, err) + + numNonMigratedLeaves := dts.GetLeavesMigrationStats()[core.NotSpecified] + numMigratedLeaves := dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled] + require.Equal(t, uint64(numLeaves-1), numNonMigratedLeaves) + require.Equal(t, uint64(1), numMigratedLeaves) + + gasLimit := uint64(100_000_000) + + numMigrateDataTrieCalls := 0 + maxExpectedNumCalls := 15 + for numNonMigratedLeaves > 0 { + migrateDataTrie(t, testContext, sndAddr, gasPrice, gasLimit, vmcommon.Ok) + numMigrateDataTrieCalls++ + + err = dtr.Commit() + require.Nil(t, err) + + rootHash, err = dtr.RootHash() + require.Nil(t, err) + + dts, err = stats.GetTrieStats("", rootHash) + require.Nil(t, err) + + require.True(t, dts.GetLeavesMigrationStats()[core.NotSpecified] < numNonMigratedLeaves) + require.True(t, dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled] > numMigratedLeaves) + + numNonMigratedLeaves = dts.GetLeavesMigrationStats()[core.NotSpecified] + numMigratedLeaves = dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled] + } + + require.True(t, numMigrateDataTrieCalls < maxExpectedNumCalls) + + require.Equal(t, uint64(0), dts.GetLeavesMigrationStats()[core.NotSpecified]) + require.Equal(t, uint64(numLeaves), dts.GetLeavesMigrationStats()[core.AutoBalanceEnabled]) + + err = testContext.Accounts.SaveAccount(acc) + require.Nil(t, err) + + acc = getAccount(t, testContext, sndAddr) + + for _, key := range nonMigratedKeys { + val, _, err := acc.RetrieveValue(key) + require.Nil(t, err) + require.Equal(t, key, val) + } + }) +} + +func generateDataTrie( + t *testing.T, + testContext *vm.VMTestContext, + accAddr []byte, + numLeaves int, + keyGenerator func(i int) []byte, +) ([]byte, [][]byte) { + acc := getAccount(t, testContext, accAddr) + keys := make([][]byte, numLeaves) + + firstKey := initDataTrie(t, testContext, acc) + keys[0] = firstKey + + dataTr := getAccountDataTrie(t, testContext, accAddr) + tr, ok := dataTr.(dataTrie) + require.True(t, ok) + + for i := 1; i < numLeaves; i++ { + key := keyGenerator(i) + err := tr.UpdateWithVersion(key, key, core.NotSpecified) + require.Nil(t, err) + + keys[i] = key + } + + rootHash := saveAccount(t, testContext, dataTr, acc) + + return rootHash, keys +} + +func initDataTrie( + t *testing.T, + testContext *vm.VMTestContext, + acc state.UserAccountHandler, +) []byte { + key := []byte("initDataTrieKey") + err := acc.SaveKeyValue(key, key) + require.Nil(t, err) + err = testContext.Accounts.SaveAccount(acc) + require.Nil(t, err) + + return key +} + +func saveAccount( + t *testing.T, + testContext *vm.VMTestContext, + dataTr common.Trie, + acc state.UserAccountHandler, +) []byte { + rootHash, _ := dataTr.RootHash() + acc.SetRootHash(rootHash) + + err := testContext.Accounts.SaveAccount(acc) + require.Nil(t, err) + + _, err = testContext.Accounts.Commit() + require.Nil(t, err) + + return rootHash +} + +func migrateDataTrie( + t *testing.T, + testContext *vm.VMTestContext, + sndAddr []byte, + gasPrice uint64, + gasLimit uint64, + expectedReturnCode vmcommon.ReturnCode, +) { + testContext.CleanIntermediateTransactions(t) + + gasLocked := "00" //use all available gas + txData := core.BuiltInFunctionMigrateDataTrie + "@" + gasLocked + + scr := &smartContractResult.SmartContractResult{ + Value: big.NewInt(0), + RcvAddr: sndAddr, + SndAddr: sndAddr, + Data: []byte(txData), + GasLimit: gasLimit, + GasPrice: gasPrice, + CallType: 1, + } + returnCode, errProcess := testContext.ScProcessor.ProcessSmartContractResult(scr) + if expectedReturnCode == vmcommon.Ok { + require.Nil(t, errProcess) + } else { + require.NotNil(t, errProcess) + } + + require.Equal(t, expectedReturnCode, returnCode) +} + +func testGasConsumed( + t *testing.T, + testContext *vm.VMTestContext, + gasLimit uint64, + expectedGasConsumed uint64, +) { + intermediate := testContext.GetIntermediateTransactions(t) + require.Equal(t, 1, len(intermediate)) + + gasConsumed := gasLimit - intermediate[0].GetGasLimit() + fmt.Println("gas consumed", gasConsumed) + fmt.Println("expected gas consumed", expectedGasConsumed) + require.Equal(t, expectedGasConsumed, gasConsumed) +} + +func getAccount(t *testing.T, testContext *vm.VMTestContext, scAddress []byte) state.UserAccountHandler { + scAcc, err := testContext.Accounts.LoadAccount(scAddress) + require.Nil(t, err) + acc, ok := scAcc.(state.UserAccountHandler) + require.True(t, ok) + + return acc +} + +func getAccountDataTrie(t *testing.T, testContext *vm.VMTestContext, address []byte) common.Trie { + acc := getAccount(t, testContext, address) + dataTrie, ok := acc.DataTrie().(common.Trie) + require.True(t, ok) + + return dataTrie +} diff --git a/integrationTests/vm/wasm/delegation/testRunner.go b/integrationTests/vm/wasm/delegation/testRunner.go index 3623f070858..6d9d94bae5c 100644 --- a/integrationTests/vm/wasm/delegation/testRunner.go +++ b/integrationTests/vm/wasm/delegation/testRunner.go @@ -53,7 +53,12 @@ func RunDelegationStressTest( MaxBatchSize: 45000, MaxOpenFiles: 10, } - persisterFactory := factory.NewPersisterFactory(dbConfig) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + persisterFactory, err := factory.NewPersisterFactory(dbConfigHandler) + if err != nil { + return nil, err + } + tempDir, err := ioutil.TempDir("", "integrationTest") if err != nil { return nil, err diff --git a/integrationTests/vm/wasm/wasmvm/wasmVM_test.go b/integrationTests/vm/wasm/wasmvm/wasmVM_test.go index e8d4cdce9e3..5bc996b3bd6 100644 --- a/integrationTests/vm/wasm/wasmvm/wasmVM_test.go +++ b/integrationTests/vm/wasm/wasmvm/wasmVM_test.go @@ -33,6 +33,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/integrationtests" logger "github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -504,6 +505,7 @@ func TestExecuteTransactionAndTimeToProcessChange(t *testing.T) { testHasher := sha256.NewSha256() shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) pubkeyConv, _ := pubkeyConverter.NewHexPubkeyConverter(32) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} accnts := integrationtests.CreateInMemoryShardAccountsDB() esdtTransferParser, _ := parsers.NewESDTTransferParser(testMarshalizer) argsTxTypeHandler := coordinator.ArgNewTxTypeHandler{ @@ -512,7 +514,7 @@ func TestExecuteTransactionAndTimeToProcessChange(t *testing.T) { BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: enableEpochsHandler, } txTypeHandler, _ := coordinator.NewTxTypeHandler(argsTxTypeHandler) feeHandler := &economicsmocks.EconomicsHandlerStub{ @@ -542,7 +544,7 @@ func TestExecuteTransactionAndTimeToProcessChange(t *testing.T) { BadTxForwarder: &mock.IntermediateTransactionHandlerMock{}, ArgsParser: smartContract.NewArgumentParser(), ScrForwarder: &mock.IntermediateTransactionHandlerMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } txProc, _ := processTransaction.NewTxProcessor(argsNewTxProcessor) diff --git a/node/external/blockAPI/apiBlockFactory_test.go b/node/external/blockAPI/apiBlockFactory_test.go index 0ecd79b8b73..679aa6c0e1a 100644 --- a/node/external/blockAPI/apiBlockFactory_test.go +++ b/node/external/blockAPI/apiBlockFactory_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/state" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -39,7 +40,7 @@ func createMockArgsAPIBlockProc() *ArgAPIBlockProcessor { AlteredAccountsProvider: &testscommon.AlteredAccountsProviderStub{}, AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/node/external/blockAPI/internalBlock_test.go b/node/external/blockAPI/internalBlock_test.go index ef016d8122e..b653eaee42d 100644 --- a/node/external/blockAPI/internalBlock_test.go +++ b/node/external/blockAPI/internalBlock_test.go @@ -12,9 +12,10 @@ import ( "github.com/multiversx/mx-chain-go/node/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -50,7 +51,7 @@ func createMockInternalBlockProcessor( return false }, }, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) } @@ -63,13 +64,13 @@ func TestInternalBlockProcessor_ConvertShardBlockBytesToInternalBlockShouldFail( ibp := newInternalBlockProcessor( &ArgAPIBlockProcessor{ - Marshalizer: &testscommon.MarshalizerStub{ + Marshalizer: &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(_ interface{}, buff []byte) error { return expectedErr }, }, HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) wrongBytes := []byte{0, 1, 2} @@ -84,9 +85,9 @@ func TestInternalBlockProcessor_ConvertShardBlockBytesToInternalBlockShouldWork( ibp := newInternalBlockProcessor( &ArgAPIBlockProcessor{ - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) header := &block.Header{ @@ -331,13 +332,13 @@ func TestInternalBlockProcessor_ConvertMetaBlockBytesToInternalBlock_ShouldFail( ibp := newInternalBlockProcessor( &ArgAPIBlockProcessor{ - Marshalizer: &testscommon.MarshalizerStub{ + Marshalizer: &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(_ interface{}, buff []byte) error { return expectedErr }, }, HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) wrongBytes := []byte{0, 1, 2} @@ -352,9 +353,9 @@ func TestInternalBlockProcessor_ConvertMetaBlockBytesToInternalBlockShouldWork(t ibp := newInternalBlockProcessor( &ArgAPIBlockProcessor{ - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) header := &block.MetaBlock{ @@ -614,7 +615,7 @@ func TestInternalBlockProcessor_GetInternalMiniBlockByHash(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalMiniBlock(common.ApiOutputFormatJSON, []byte("invalidHash"), 1) @@ -643,7 +644,7 @@ func TestInternalBlockProcessor_GetInternalMiniBlockByHash(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalMiniBlock(common.ApiOutputFormatJSON, []byte("invalidHash"), 1) @@ -672,7 +673,7 @@ func TestInternalBlockProcessor_GetInternalMiniBlockByHash(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalMiniBlock(common.ApiOutputFormatJSON, []byte("invalidHash"), 1) @@ -706,7 +707,7 @@ func TestInternalBlockProcessor_GetInternalMiniBlockByHash(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalMiniBlock(common.ApiOutputFormatProto, miniBlockHash, 1) @@ -741,7 +742,7 @@ func TestInternalBlockProcessor_GetInternalMiniBlockByHash(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalMiniBlock(common.ApiOutputFormatJSON, miniBlockHash, expEpoch) @@ -770,7 +771,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochMetaBlock(t *testing.T) { Store: &storageMocks.ChainStorerStub{}, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalStartOfEpochMetaBlock(common.ApiOutputFormatJSON, expEpoch) @@ -799,7 +800,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochMetaBlock(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalStartOfEpochMetaBlock(common.ApiOutputFormatJSON, expEpoch) @@ -828,7 +829,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochMetaBlock(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalStartOfEpochMetaBlock(common.ApiOutputFormatProto, expEpoch) @@ -857,7 +858,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochMetaBlock(t *testing.T) { }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, nil) blk, err := ibp.GetInternalStartOfEpochMetaBlock(common.ApiOutputFormatJSON, expEpoch) @@ -881,7 +882,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochValidatorsInfo(t *testing Store: &storageMocks.ChainStorerStub{}, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, }, }, nil) @@ -912,7 +913,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochValidatorsInfo(t *testing }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRefactorPeersMiniBlocksFlagEnabledField: true, }, }, nil) @@ -986,7 +987,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochValidatorsInfo(t *testing }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ RefactorPeersMiniBlocksEnableEpochField: 5, }, }, nil) @@ -1071,7 +1072,7 @@ func TestInternalBlockProcessor_GetInternalStartOfEpochValidatorsInfo(t *testing }, Uint64ByteSliceConverter: mock.NewNonceHashConverterMock(), HistoryRepo: &dblookupext.HistoryRepositoryStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ RefactorPeersMiniBlocksEnableEpochField: 5, }, }, nil) diff --git a/node/external/blockAPI/metaBlock_test.go b/node/external/blockAPI/metaBlock_test.go index b3cbb6d4ffb..256c9e922ec 100644 --- a/node/external/blockAPI/metaBlock_test.go +++ b/node/external/blockAPI/metaBlock_test.go @@ -21,6 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -828,7 +829,7 @@ func TestMetaAPIBlockProcessor_GetAlteredAccountsForBlock(t *testing.T) { t.Run("get altered account by block hash - should work", func(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} headerHash := []byte("d08089f2ab739520598fd7aeed08c427460fe94f286383047f3f61951afc4e00") mbHash := []byte("mb-hash") txHash0, txHash1 := []byte("tx-hash-0"), []byte("tx-hash-1") diff --git a/node/external/blockAPI/shardBlock_test.go b/node/external/blockAPI/shardBlock_test.go index 5dd2f2bf8fb..2af941eea9b 100644 --- a/node/external/blockAPI/shardBlock_test.go +++ b/node/external/blockAPI/shardBlock_test.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -597,7 +598,7 @@ func TestShardAPIBlockProcessor_GetAlteredAccountsForBlock(t *testing.T) { t.Run("get altered account by block hash - should work", func(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} headerHash := []byte("d08089f2ab739520598fd7aeed08c427460fe94f286383047f3f61951afc4e00") mbHash := []byte("mb-hash") txHash0, txHash1 := []byte("tx-hash-0"), []byte("tx-hash-1") diff --git a/node/external/logs/logsFacade_test.go b/node/external/logs/logsFacade_test.go index 21d11f99c59..faa06d7fbdf 100644 --- a/node/external/logs/logsFacade_test.go +++ b/node/external/logs/logsFacade_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/require" ) @@ -15,7 +16,7 @@ func TestNewLogsFacade(t *testing.T) { t.Run("NilStorageService", func(t *testing.T) { arguments := ArgsNewLogsFacade{ StorageService: nil, - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, PubKeyConverter: testscommon.NewPubkeyConverterMock(32), } @@ -41,7 +42,7 @@ func TestNewLogsFacade(t *testing.T) { t.Run("NilPubKeyConverter", func(t *testing.T) { arguments := ArgsNewLogsFacade{ StorageService: genericMocks.NewChainStorerMock(7), - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, PubKeyConverter: nil, } diff --git a/node/external/logs/logsRepository_test.go b/node/external/logs/logsRepository_test.go index 4e7fda4e146..030fcef27ca 100644 --- a/node/external/logs/logsRepository_test.go +++ b/node/external/logs/logsRepository_test.go @@ -9,8 +9,8 @@ import ( storageCore "github.com/multiversx/mx-chain-core-go/storage" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -25,13 +25,13 @@ func TestNewLogsRepository(t *testing.T) { GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { return nil, errors.New("new error") }, - }, testscommon.MarshalizerMock{}) + }, marshallerMock.MarshalizerMock{}) require.Nil(t, repository) }) t.Run("should work", func(t *testing.T) { t.Parallel() - repository := newLogsRepository(&genericMocks.ChainStorerMock{}, testscommon.MarshalizerMock{}) + repository := newLogsRepository(&genericMocks.ChainStorerMock{}, marshallerMock.MarshalizerMock{}) require.NotNil(t, repository) }) } diff --git a/node/external/transactionAPI/apiTransactionResults_test.go b/node/external/transactionAPI/apiTransactionResults_test.go index 7676ade83ad..70b4237edc6 100644 --- a/node/external/transactionAPI/apiTransactionResults_test.go +++ b/node/external/transactionAPI/apiTransactionResults_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dbLookupExtMock "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" "github.com/stretchr/testify/require" @@ -105,8 +106,8 @@ func TestApiTransactionProcessor_PutResultsInTransactionWhenNoResultsShouldWork( testscommon.RealWorldBech32PubkeyConverter, historyRepo, genericMocks.NewChainStorerMock(epoch), - &testscommon.MarshalizerMock{}, - newTransactionUnmarshaller(&testscommon.MarshalizerMock{}, testscommon.RealWorldBech32PubkeyConverter, dataFieldParser, shardCoordinator), + &marshallerMock.MarshalizerMock{}, + newTransactionUnmarshaller(&marshallerMock.MarshalizerMock{}, testscommon.RealWorldBech32PubkeyConverter, dataFieldParser, shardCoordinator), &testscommon.LogsFacadeStub{}, shardCoordinator, dataFieldParser, diff --git a/node/interface.go b/node/interface.go index 6dd9517d233..23a706ed25a 100644 --- a/node/interface.go +++ b/node/interface.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/update" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // NetworkShardingCollector defines the updating methods used by the network sharding component @@ -56,3 +57,8 @@ type HealthService interface { io.Closer RegisterComponent(component interface{}) } + +type accountHandlerWithDataTrieMigrationStatus interface { + vmcommon.AccountHandler + IsDataTrieMigrated() (bool, error) +} diff --git a/node/mock/accountFactoryStub.go b/node/mock/accountFactoryStub.go deleted file mode 100644 index 2912f6fedd1..00000000000 --- a/node/mock/accountFactoryStub.go +++ /dev/null @@ -1,18 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-vm-common-go" - -// AccountsFactoryStub - -type AccountsFactoryStub struct { - CreateAccountCalled func(address []byte) (vmcommon.AccountHandler, error) -} - -// CreateAccount - -func (afs *AccountsFactoryStub) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return afs.CreateAccountCalled(address) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (afs *AccountsFactoryStub) IsInterfaceNil() bool { - return afs == nil -} diff --git a/node/node.go b/node/node.go index 1d5cc001adb..b08a4b8b925 100644 --- a/node/node.go +++ b/node/node.go @@ -38,7 +38,6 @@ import ( procTx "github.com/multiversx/mx-chain-go/process/transaction" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/trie" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" logger "github.com/multiversx/mx-chain-logger-go" @@ -220,16 +219,11 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin return tokens, nil } - rootHash, err := userAccount.DataTrie().RootHash() - if err != nil { - return nil, err - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = userAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, err } @@ -245,7 +239,7 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin continue } - esdtToken, okGet := n.getEsdtDataFromLeaf(leaf, userAccount) + esdtToken, okGet := n.getEsdtDataFromLeaf(leaf) if !okGet { continue } @@ -267,16 +261,10 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin return tokens, nil } -func (n *Node) getEsdtDataFromLeaf(leaf core.KeyValueHolder, userAccount state.UserAccountHandler) (*systemSmartContracts.ESDTDataV2, bool) { +func (n *Node) getEsdtDataFromLeaf(leaf core.KeyValueHolder) (*systemSmartContracts.ESDTDataV2, bool) { esdtToken := &systemSmartContracts.ESDTDataV2{} - suffix := append(leaf.Key(), userAccount.AddressBytes()...) - value, errVal := leaf.ValueWithoutSuffix(suffix) - if errVal != nil { - log.Warn("cannot get value without suffix", "error", errVal, "key", leaf.Key()) - return nil, false - } - err := n.coreComponents.InternalMarshalizer().Unmarshal(esdtToken, value) + err := n.coreComponents.InternalMarshalizer().Unmarshal(esdtToken, leaf.Value()) if err != nil { log.Warn("cannot unmarshal esdt data", "err", err) return nil, false @@ -301,30 +289,18 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return map[string]string{}, api.BlockInfo{}, nil } - rootHash, err := userAccount.DataTrie().RootHash() - if err != nil { - return nil, api.BlockInfo{}, err - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = userAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, api.BlockInfo{}, err } mapToReturn := make(map[string]string) for leaf := range chLeaves.LeavesChan { - suffix := append(leaf.Key(), userAccount.AddressBytes()...) - value, errVal := leaf.ValueWithoutSuffix(suffix) - if errVal != nil { - log.Warn("cannot get value without suffix", "error", errVal, "key", leaf.Key()) - continue - } - - mapToReturn[hex.EncodeToString(leaf.Key())] = hex.EncodeToString(value) + mapToReturn[hex.EncodeToString(leaf.Key())] = hex.EncodeToString(leaf.Value()) } err = chLeaves.ErrChan.ReadFromChanNonBlocking() @@ -475,16 +451,11 @@ func (n *Node) getTokensIDsWithFilter( return tokens, api.BlockInfo{}, nil } - rootHash, err := userAccount.DataTrie().RootHash() - if err != nil { - return nil, api.BlockInfo{}, err - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = userAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, api.BlockInfo{}, err } @@ -495,7 +466,7 @@ func (n *Node) getTokensIDsWithFilter( continue } - esdtToken, okGet := n.getEsdtDataFromLeaf(leaf, userAccount) + esdtToken, okGet := n.getEsdtDataFromLeaf(leaf) if !okGet { continue } @@ -617,16 +588,11 @@ func (n *Node) GetAllESDTTokens(address string, options api.AccountQueryOptions, esdtPrefix := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier) lenESDTPrefix := len(esdtPrefix) - rootHash, err := userAccount.DataTrie().RootHash() - if err != nil { - return nil, api.BlockInfo{}, err - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = userAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, api.BlockInfo{}, err } @@ -1369,9 +1335,13 @@ func (n *Node) GetProofDataTrie(rootHash string, address string, key string) (*c return nil, nil, err } - dataTrieProofResponse, err := n.getProof(dataTrieRootHash, keyBytes) + dataTrieKey := n.coreComponents.Hasher().Compute(string(keyBytes)) + dataTrieProofResponse, err := n.getProof(dataTrieRootHash, dataTrieKey) if err != nil { - return nil, nil, err + dataTrieProofResponse, err = n.getProof(dataTrieRootHash, keyBytes) + if err != nil { + return nil, nil, err + } } dataTrieProofResponse.Value = value @@ -1399,6 +1369,21 @@ func (n *Node) VerifyProof(rootHash string, address string, proof [][]byte) (boo return mpv.VerifyProof(rootHashBytes, key, proof) } +// IsDataTrieMigrated returns true if the data trie for the given address is migrated +func (n *Node) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) { + accountHandler, _, err := n.loadUserAccountHandlerByAddress(address, options) + if err != nil { + return false, err + } + + acc, ok := accountHandler.(accountHandlerWithDataTrieMigrationStatus) + if !ok { + return false, fmt.Errorf("wrong type assertion for address %s, account type %T", address, accountHandler) + } + + return acc.IsDataTrieMigrated() +} + func (n *Node) getRootHashAndAddressAsBytes(rootHash string, address string) ([]byte, []byte, error) { rootHashBytes, err := hex.DecodeString(rootHash) if err != nil { diff --git a/node/nodeLoadAccounts_test.go b/node/nodeLoadAccounts_test.go index 8acb60a3c56..e7e03c2d05b 100644 --- a/node/nodeLoadAccounts_test.go +++ b/node/nodeLoadAccounts_test.go @@ -24,8 +24,8 @@ import ( func TestNode_GetAccountWithOptionsShouldWork(t *testing.T) { t.Parallel() - alice, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) - alice.Balance = big.NewInt(100) + alice := createAcc(testscommon.TestPubKeyAlice) + _ = alice.AddToBalance(big.NewInt(100)) accountsRepostitory := &mockState.AccountsRepositoryStub{} accountsRepostitory.GetAccountWithBlockInfoCalled = func(pubkey []byte, options api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { diff --git a/node/nodeRunner.go b/node/nodeRunner.go index db89809d43b..024a2e19020 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -681,6 +681,7 @@ func getBaseAccountSyncerArgs( CheckNodesOnDisk: true, UserAccountsSyncStatisticsHandler: trieStatistics.NewTrieSyncStatistics(), AppStatusHandler: disabled.NewAppStatusHandler(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } } diff --git a/node/node_test.go b/node/node_test.go index a22e98e1b58..1a15b79cb39 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -49,10 +49,13 @@ import ( dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" factoryTests "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/mainFactoryMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -90,10 +93,21 @@ func createMockPubkeyConverter() *testscommon.PubkeyConverterMock { return testscommon.NewPubkeyConverterMock(32) } +func createAcc(address []byte) state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount(address, argsAccCreation) + + return acc +} + func getAccAdapter(balance *big.Int) *stateMock.AccountsStub { accDB := &stateMock.AccountsStub{} accDB.GetExistingAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { - acc, _ := state.NewUserAccount(address) + acc := createAcc(address) _ = acc.AddToBalance(balance) acc.IncreaseNonce(1) @@ -257,7 +271,12 @@ func TestNode_GetBalanceAccNotFoundShouldReturnEmpty(t *testing.T) { func TestGetBalance(t *testing.T) { t.Parallel() - testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice, argsAccCreation) testAccount.Balance = big.NewInt(100) accountsRepository := &stateMock.AccountsRepositoryStub{ @@ -290,7 +309,12 @@ func TestGetUsername(t *testing.T) { expectedUsername := []byte("elrond") - testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice, argsAccCreation) testAccount.UserName = expectedUsername accountsRepository := &stateMock.AccountsRepositoryStub{ GetAccountWithBlockInfoCalled: func(address []byte, options api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -361,7 +385,12 @@ func TestGetCodeHash(t *testing.T) { expectedCodeHash := []byte("hash") - testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + testAccount, _ := state.NewUserAccount(testscommon.TestPubKeyAlice, argsAccCreation) testAccount.CodeHash = expectedCodeHash accountsRepository := &stateMock.AccountsRepositoryStub{ GetAccountWithBlockInfoCalled: func(address []byte, options api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -430,7 +459,7 @@ func TestNode_GetKeyValuePairsAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetKeyValuePairs(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + acc := createAcc([]byte("newaddress")) k1, v1 := []byte("key1"), []byte("value1") k2, v2 := []byte("key2"), []byte("value2") @@ -438,14 +467,14 @@ func TestNode_GetKeyValuePairs(t *testing.T) { accDB := &stateMock.AccountsStub{} acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { suffix := append(k1, acc.AddressBytes()...) - trieLeaf := keyValStorage.NewKeyValStorage(k1, append(v1, suffix...)) + trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf suffix = append(k2, acc.AddressBytes()...) - trieLeaf2 := keyValStorage.NewKeyValStorage(k2, append(v2, suffix...)) + trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf2 close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -498,14 +527,14 @@ func TestNode_GetKeyValuePairs(t *testing.T) { func TestNode_GetKeyValuePairs_GetAllLeavesShouldFail(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + acc := createAcc([]byte("newaddress")) accDB := &stateMock.AccountsStub{} expectedErr := errors.New("expected err") acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(leavesChannels.LeavesChan) @@ -553,12 +582,12 @@ func TestNode_GetKeyValuePairs_GetAllLeavesShouldFail(t *testing.T) { func TestNode_GetKeyValuePairsContextShouldTimeout(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + acc := createAcc([]byte("newaddress")) accDB := &stateMock.AccountsStub{} acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) @@ -648,7 +677,7 @@ func TestNode_GetValueForKeyAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetValueForKey(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + acc := createAcc([]byte("newaddress")) k1, v1 := []byte("key1"), []byte("value1") _ = acc.SaveKeyValue(k1, v1) @@ -730,7 +759,7 @@ func TestNode_GetESDTDataAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetESDTData(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) esdtToken := "newToken" esdtData := &esdt.ESDigitalToken{Value: big.NewInt(10)} @@ -779,7 +808,7 @@ func TestNode_GetESDTData(t *testing.T) { func TestNode_GetESDTDataForNFT(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) esdtToken := "newToken" nonce := int64(100) @@ -824,7 +853,7 @@ func TestNode_GetESDTDataForNFT(t *testing.T) { func TestNode_GetAllESDTTokens(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) esdtToken := "newToken" esdtKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + esdtToken) @@ -838,7 +867,7 @@ func TestNode_GetAllESDTTokens(t *testing.T) { acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, nil) leavesChannels.LeavesChan <- trieLeaf @@ -892,12 +921,12 @@ func TestNode_GetAllESDTTokens(t *testing.T) { func TestNode_GetAllESDTTokens_GetAllLeavesShouldFail(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) expectedErr := errors.New("expected error") acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(leavesChannels.LeavesChan) @@ -948,11 +977,11 @@ func TestNode_GetAllESDTTokens_GetAllLeavesShouldFail(t *testing.T) { func TestNode_GetAllESDTTokensContextShouldTimeout(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) @@ -1044,7 +1073,7 @@ func TestNode_GetAllESDTsAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(testscommon.TestPubKeyAlice) + acc := createAcc(testscommon.TestPubKeyAlice) esdtToken := "TKKR-7q8w9e" esdtKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + esdtToken) @@ -1077,7 +1106,7 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { } acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -1140,7 +1169,7 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { func TestNode_GetAllIssuedESDTs(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + acc := createAcc([]byte("newaddress")) esdtToken := []byte("TCK-RANDOM") sftToken := []byte("SFT-RANDOM") nftToken := []byte("NFT-RANDOM") @@ -1163,15 +1192,15 @@ func TestNode_GetAllIssuedESDTs(t *testing.T) { acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf - trieLeaf = keyValStorage.NewKeyValStorage(sftToken, append(sftMarshalledData, sftSuffix...)) + trieLeaf, _ = tlp.ParseLeaf(sftToken, append(sftMarshalledData, sftSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf - trieLeaf = keyValStorage.NewKeyValStorage(nftToken, append(nftMarshalledData, nftSuffix...)) + trieLeaf, _ = tlp.ParseLeaf(nftToken, append(nftMarshalledData, nftSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -1237,7 +1266,7 @@ func TestNode_GetESDTsWithRole(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc, _ := state.NewUserAccount(addrBytes) + acc := createAcc(addrBytes) esdtToken := []byte("TCK-RANDOM") specialRoles := []*systemSmartContracts.ESDTRoles{ @@ -1255,9 +1284,9 @@ func TestNode_GetESDTsWithRole(t *testing.T) { acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -1317,7 +1346,7 @@ func TestNode_GetESDTsRoles(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc, _ := state.NewUserAccount(addrBytes) + acc := createAcc(addrBytes) esdtToken := []byte("TCK-RANDOM") specialRoles := []*systemSmartContracts.ESDTRoles{ @@ -1335,9 +1364,9 @@ func TestNode_GetESDTsRoles(t *testing.T) { acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -1389,7 +1418,7 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc, _ := state.NewUserAccount(addrBytes) + acc := createAcc(addrBytes) esdtToken := []byte("TCK-RANDOM") esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.SemiFungibleESDT), OwnerAddress: addrBytes} @@ -1400,9 +1429,9 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -1454,11 +1483,11 @@ func TestNode_GetNFTTokenIDsRegisteredByAddressContextShouldTimeout(t *testing.T t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc, _ := state.NewUserAccount(addrBytes) + acc := createAcc(addrBytes) acc.SetDataTrie( &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) @@ -1625,7 +1654,7 @@ func TestGenerateTransaction_GetAccountReturnsNilShouldWork(t *testing.T) { accAdapter := &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createAcc(address), nil }, } privateKey := getPrivateKey() @@ -1772,7 +1801,7 @@ func TestGenerateTransaction_ShouldSetCorrectNonce(t *testing.T) { nonce := uint64(7) accAdapter := &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - acc, _ := state.NewUserAccount(address) + acc := createAcc(address) _ = acc.AddToBalance(big.NewInt(0)) acc.IncreaseNonce(nonce) @@ -2740,7 +2769,7 @@ func TestCreateTransaction_OkValsShouldWork(t *testing.T) { stateComponents := getDefaultStateComponents() stateComponents.AccountsAPI = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount([]byte("address")) + return createAcc([]byte("address")), nil }, } @@ -3368,7 +3397,7 @@ func TestNode_GetAccountAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetAccountAccountExistsShouldReturn(t *testing.T) { t.Parallel() - accnt, _ := state.NewUserAccount(testscommon.TestPubKeyBob) + accnt := createAcc(testscommon.TestPubKeyBob) _ = accnt.AddToBalance(big.NewInt(1)) accnt.IncreaseNonce(2) accnt.SetRootHash([]byte("root hash")) @@ -4102,7 +4131,7 @@ func TestNode_GetProofDataTrieShouldWork(t *testing.T) { return dataTrieProof, dataTrieValue, nil } - return nil, nil, nil + return nil, nil, fmt.Errorf("key not found") }, }, nil }, @@ -4186,6 +4215,118 @@ func TestNode_VerifyProof(t *testing.T) { assert.Nil(t, err) } +func TestNode_IsDataTrieMigrated(t *testing.T) { + t.Parallel() + + t.Run("invalid address", func(t *testing.T) { + t.Parallel() + + n, _ := node.NewNode( + node.WithStateComponents(getDefaultStateComponents()), + node.WithCoreComponents(getDefaultCoreComponents()), + ) + + isMigrated, err := n.IsDataTrieMigrated("invalid address", api.AccountQueryOptions{}) + assert.False(t, isMigrated) + assert.NotNil(t, err) + }) + + t.Run("load account err", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("load account error") + stateComponents := getDefaultStateComponents() + stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ + GetAccountWithBlockInfoCalled: func(_ []byte, _ api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { + return nil, nil, expectedErr + }, + } + + n, _ := node.NewNode( + node.WithStateComponents(stateComponents), + node.WithCoreComponents(getDefaultCoreComponents()), + ) + + isMigrated, err := n.IsDataTrieMigrated("erd1qqqqqqqqqqqqqqqpqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqplllst77y4l", api.AccountQueryOptions{}) + assert.False(t, isMigrated) + assert.Equal(t, expectedErr, err) + }) + + t.Run("wrong type assertion", func(t *testing.T) { + t.Parallel() + + stateComponents := getDefaultStateComponents() + stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ + GetAccountWithBlockInfoCalled: func(_ []byte, _ api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { + return &stateMock.AccountWrapMock{}, nil, nil + }, + } + + n, _ := node.NewNode( + node.WithStateComponents(stateComponents), + node.WithCoreComponents(getDefaultCoreComponents()), + ) + + isMigrated, err := n.IsDataTrieMigrated("erd1qqqqqqqqqqqqqqqpqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqplllst77y4l", api.AccountQueryOptions{}) + assert.False(t, isMigrated) + assert.True(t, strings.Contains(err.Error(), "wrong type assertion")) + }) + + t.Run("should work and return false", func(t *testing.T) { + t.Parallel() + + acc := createAcc([]byte("000000000000000000010000000000000000000000000000000000000001ffff")) + acc.SetDataTrie(&trieMock.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return false, nil + }, + }) + + stateComponents := getDefaultStateComponents() + stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ + GetAccountWithBlockInfoCalled: func(_ []byte, _ api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { + return acc, nil, nil + }, + } + + n, _ := node.NewNode( + node.WithStateComponents(stateComponents), + node.WithCoreComponents(getDefaultCoreComponents()), + ) + + isMigrated, err := n.IsDataTrieMigrated("erd1qqqqqqqqqqqqqqqpqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqplllst77y4l", api.AccountQueryOptions{}) + assert.False(t, isMigrated) + assert.Nil(t, err) + }) + + t.Run("should work and return true", func(t *testing.T) { + t.Parallel() + + acc := createAcc([]byte("000000000000000000010000000000000000000000000000000000000001ffff")) + acc.SetDataTrie(&trieMock.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return true, nil + }, + }) + + stateComponents := getDefaultStateComponents() + stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ + GetAccountWithBlockInfoCalled: func(_ []byte, _ api.AccountQueryOptions) (vmcommon.AccountHandler, common.BlockInfo, error) { + return acc, nil, nil + }, + } + + n, _ := node.NewNode( + node.WithStateComponents(stateComponents), + node.WithCoreComponents(getDefaultCoreComponents()), + ) + + isMigrated, err := n.IsDataTrieMigrated("erd1qqqqqqqqqqqqqqqpqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqplllst77y4l", api.AccountQueryOptions{}) + assert.True(t, isMigrated) + assert.Nil(t, err) + }) +} + func TestGetESDTSupplyError(t *testing.T) { t.Parallel() @@ -4591,7 +4732,12 @@ func TestNode_setTxGuardianData(t *testing.T) { func TestNode_GetGuardianData(t *testing.T) { userAddressBytes := bytes.Repeat([]byte{3}, 32) - testAccount, _ := state.NewUserAccount(userAddressBytes) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + testAccount, _ := state.NewUserAccount(userAddressBytes, argsAccCreation) testAccountsDB := &stateMock.AccountsStub{ GetAccountWithBlockInfoCalled: func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { return testAccount, nil, nil @@ -4777,7 +4923,7 @@ func TestNode_GetGuardianData(t *testing.T) { require.Nil(t, err) }) t.Run("one active and one pending and account guarded", func(t *testing.T) { - acc, _ := state.NewUserAccount(userAddressBytes) + acc, _ := state.NewUserAccount(userAddressBytes, argsAccCreation) acc.CodeMetadata = (&vmcommon.CodeMetadata{Guarded: true}).ToBytes() accDB := &stateMock.AccountsStub{ GetAccountWithBlockInfoCalled: func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -4929,9 +5075,9 @@ func TestNode_getPendingAndActiveGuardians(t *testing.T) { func getDefaultCoreComponents() *nodeMockFactory.CoreComponentsMock { return &nodeMockFactory.CoreComponentsMock{ - IntMarsh: &testscommon.MarshalizerMock{}, - TxMarsh: &testscommon.MarshalizerMock{}, - VmMarsh: &testscommon.MarshalizerMock{}, + IntMarsh: &marshallerMock.MarshalizerMock{}, + TxMarsh: &marshallerMock.MarshalizerMock{}, + VmMarsh: &marshallerMock.MarshalizerMock{}, TxSignHasherField: &testscommon.HasherStub{}, Hash: &testscommon.HasherStub{}, UInt64ByteSliceConv: testscommon.NewNonceHashConverterMock(), diff --git a/node/trieIterators/delegatedListProcessor.go b/node/trieIterators/delegatedListProcessor.go index cf257a79e4b..db83f3d1b92 100644 --- a/node/trieIterators/delegatedListProcessor.go +++ b/node/trieIterators/delegatedListProcessor.go @@ -14,7 +14,6 @@ import ( "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -133,16 +132,11 @@ func (dlp *delegatedListProcessor) getDelegatorsList(delegationSC []byte, ctx co return nil, fmt.Errorf("%w for delegationSC %s", err, hex.EncodeToString(delegationSC)) } - rootHash, err := delegatorAccount.DataTrie().RootHash() - if err != nil { - return nil, fmt.Errorf("%w for delegationSC %s", err, hex.EncodeToString(delegationSC)) - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = delegatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = delegatorAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, err } diff --git a/node/trieIterators/delegatedListProcessor_test.go b/node/trieIterators/delegatedListProcessor_test.go index c240e0a4b29..97eea67188b 100644 --- a/node/trieIterators/delegatedListProcessor_test.go +++ b/node/trieIterators/delegatedListProcessor_test.go @@ -9,16 +9,12 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-core-go/data/api" - "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/node/mock" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -119,7 +115,7 @@ func TestDelegatedListProc_GetDelegatorsListContextShouldTimeout(t *testing.T) { } arg.Accounts.AccountsAdapter = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return createDelegationScAccount(addressContainer, delegators, addressContainer, time.Second), nil + return createScAccount(addressContainer, delegators, addressContainer, time.Second), nil }, RecreateTrieCalled: func(rootHash []byte) error { return nil @@ -166,7 +162,7 @@ func TestDelegatedListProc_GetDelegatorsListShouldWork(t *testing.T) { } arg.Accounts.AccountsAdapter = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return createDelegationScAccount(addressContainer, delegators, addressContainer, 0), nil + return createScAccount(addressContainer, delegators, addressContainer, 0), nil }, RecreateTrieCalled: func(rootHash []byte) error { return nil @@ -234,28 +230,3 @@ func TestDelegatedListProcessor_IsInterfaceNil(t *testing.T) { dlp, _ = NewDelegatedListProcessor(createMockArgs()) require.False(t, dlp.IsInterfaceNil()) } - -func createDelegationScAccount(address []byte, leaves [][]byte, rootHash []byte, timeSleep time.Duration) state.UserAccountHandler { - acc, _ := state.NewUserAccount(address) - acc.SetDataTrie(&trieMock.TrieStub{ - RootCalled: func() ([]byte, error) { - return rootHash, nil - }, - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { - go func() { - time.Sleep(timeSleep) - for _, leafBuff := range leaves { - leaf := keyValStorage.NewKeyValStorage(leafBuff, nil) - leavesChannels.LeavesChan <- leaf - } - - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - }) - - return acc -} diff --git a/node/trieIterators/directStakedListProcessor.go b/node/trieIterators/directStakedListProcessor.go index 7193b5de2de..ce5f415396b 100644 --- a/node/trieIterators/directStakedListProcessor.go +++ b/node/trieIterators/directStakedListProcessor.go @@ -9,7 +9,6 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -53,16 +52,11 @@ func (dslp *directStakedListProcessor) GetDirectStakedList(ctx context.Context) } func (dslp *directStakedListProcessor) getAllStakedAccounts(validatorAccount state.UserAccountHandler, ctx context.Context) ([]*api.DirectStakedValue, error) { - rootHash, err := validatorAccount.DataTrie().RootHash() - if err != nil { - return nil, err - } - chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err := validatorAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, err } diff --git a/node/trieIterators/directStakedListProcessor_test.go b/node/trieIterators/directStakedListProcessor_test.go index 29398b7bcb6..8b4aa932edb 100644 --- a/node/trieIterators/directStakedListProcessor_test.go +++ b/node/trieIterators/directStakedListProcessor_test.go @@ -16,6 +16,9 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -70,7 +73,7 @@ func TestDirectStakedListProc_GetDelegatorsListContextShouldTimeout(t *testing.T } arg.Accounts.AccountsAdapter = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return createValidatorScAccount(addressContainer, validators, addressContainer, time.Second), nil + return createScAccount(addressContainer, validators, addressContainer, time.Second), nil }, RecreateTrieCalled: func(rootHash []byte) error { return nil @@ -114,7 +117,7 @@ func TestDirectStakedListProc_GetDelegatorsListShouldWork(t *testing.T) { } arg.Accounts.AccountsAdapter = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return createValidatorScAccount(addressContainer, validators, addressContainer, 0), nil + return createScAccount(addressContainer, validators, addressContainer, 0), nil }, RecreateTrieCalled: func(rootHash []byte) error { return nil @@ -146,13 +149,18 @@ func TestDirectStakedListProc_GetDelegatorsListShouldWork(t *testing.T) { assert.Equal(t, []*api.DirectStakedValue{&expectedDirectStake1, &expectedDirectStake2}, directStakedList) } -func createValidatorScAccount(address []byte, leaves [][]byte, rootHash []byte, timeSleep time.Duration) state.UserAccountHandler { - acc, _ := state.NewUserAccount(address) +func createScAccount(address []byte, leaves [][]byte, rootHash []byte, timeSleep time.Duration) state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount(address, argsAccCreation) acc.SetDataTrie(&trieMock.TrieStub{ RootCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { time.Sleep(timeSleep) for _, leafBuff := range leaves { diff --git a/node/trieIterators/stakeValuesProcessor.go b/node/trieIterators/stakeValuesProcessor.go index 17109690b98..843725d1067 100644 --- a/node/trieIterators/stakeValuesProcessor.go +++ b/node/trieIterators/stakeValuesProcessor.go @@ -13,7 +13,6 @@ import ( "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" ) @@ -91,17 +90,12 @@ func (svp *stakedValuesProcessor) computeBaseStakedAndTopUp(ctx context.Context) return nil, nil, err } - rootHash, err := validatorAccount.DataTrie().RootHash() - if err != nil { - return nil, nil, err - } - // TODO investigate if a call to GetAllLeavesKeysOnChannel (without values) might increase performance chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) + err = validatorAccount.GetAllLeaves(chLeaves, ctx) if err != nil { return nil, nil, err } diff --git a/node/trieIterators/stakeValuesProcessor_test.go b/node/trieIterators/stakeValuesProcessor_test.go index 989cf102fde..e664a42ed41 100644 --- a/node/trieIterators/stakeValuesProcessor_test.go +++ b/node/trieIterators/stakeValuesProcessor_test.go @@ -15,6 +15,9 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/vm" @@ -165,7 +168,12 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_CannotGetRootHash(t *test t.Parallel() expectedErr := errors.New("expected error") - acc, _ := state.NewUserAccount([]byte("newaddress")) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount([]byte("newaddress"), argsAccCreation) acc.SetDataTrie(&trieMock.TrieStub{ RootCalled: func() ([]byte, error) { return nil, expectedErr @@ -191,9 +199,14 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_CannotGetRootHash(t *test func TestTotalStakedValueProcessor_GetTotalStakedValue_ContextShouldTimeout(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("newaddress")) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount([]byte("newaddress"), argsAccCreation) acc.SetDataTrie(&trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { time.Sleep(time.Second) close(leavesChannels.LeavesChan) leavesChannels.ErrChan.Close() @@ -227,9 +240,14 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_CannotGetAllLeaves(t *tes t.Parallel() expectedErr := errors.New("expected error") - acc, _ := state.NewUserAccount([]byte("newaddress")) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount([]byte("newaddress"), argsAccCreation) acc.SetDataTrie(&trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { return expectedErr }, RootCalled: func() ([]byte, error) { @@ -272,12 +290,17 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue(t *testing.T) { leafKey4 := "0123456783" leafKey5 := "0123456780" leafKey6 := "0123456788" - acc, _ := state.NewUserAccount([]byte("newaddress")) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount([]byte("newaddress"), argsAccCreation) acc.SetDataTrie(&trieMock.TrieStub{ RootCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { go func() { leaf1 := keyValStorage.NewKeyValStorage(rootHash, append(marshalledData, suffix...)) channels.LeavesChan <- leaf1 diff --git a/outport/factory/hostDriverFactory_test.go b/outport/factory/hostDriverFactory_test.go index 834fa793b6c..1362cb5459e 100644 --- a/outport/factory/hostDriverFactory_test.go +++ b/outport/factory/hostDriverFactory_test.go @@ -6,7 +6,7 @@ import ( "github.com/multiversx/mx-chain-communication-go/websocket/data" "github.com/multiversx/mx-chain-go/config" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/require" ) @@ -20,7 +20,7 @@ func TestCreateHostDriver(t *testing.T) { MarshallerType: "json", Mode: data.ModeClient, }, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerStub{}, } driver, err := CreateHostDriver(args) diff --git a/outport/factory/notifierFactory_test.go b/outport/factory/notifierFactory_test.go index ae38fe7964b..d4f20564d30 100644 --- a/outport/factory/notifierFactory_test.go +++ b/outport/factory/notifierFactory_test.go @@ -5,7 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/outport/factory" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/require" ) @@ -17,7 +17,7 @@ func createMockNotifierFactoryArgs() *factory.EventNotifierFactoryArgs { Username: "", Password: "", RequestTimeoutSec: 1, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, } } diff --git a/outport/mock/driverStub.go b/outport/mock/driverStub.go index e9f4e4a56ab..9cac9437aa8 100644 --- a/outport/mock/driverStub.go +++ b/outport/mock/driverStub.go @@ -3,7 +3,7 @@ package mock import ( outportcore "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" ) // DriverStub - @@ -83,7 +83,7 @@ func (d *DriverStub) FinalizedBlock(finalizedBlock *outportcore.FinalizedBlock) // GetMarshaller - func (d *DriverStub) GetMarshaller() marshal.Marshalizer { - return testscommon.MarshalizerMock{} + return marshallerMock.MarshalizerMock{} } // Close - diff --git a/outport/notifier/eventNotifier_test.go b/outport/notifier/eventNotifier_test.go index 60a3d354206..988230cd190 100644 --- a/outport/notifier/eventNotifier_test.go +++ b/outport/notifier/eventNotifier_test.go @@ -10,7 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-go/outport/mock" "github.com/multiversx/mx-chain-go/outport/notifier" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" outportStub "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,7 +19,7 @@ import ( func createMockEventNotifierArgs() notifier.ArgsEventNotifier { return notifier.ArgsEventNotifier{ HttpClient: &mock.HTTPClientStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, BlockContainer: &outportStub.BlockContainerStub{}, } } diff --git a/outport/process/alteredaccounts/alteredAccountsProvider_test.go b/outport/process/alteredaccounts/alteredAccountsProvider_test.go index 64275bfff81..a3f6c6238aa 100644 --- a/outport/process/alteredaccounts/alteredAccountsProvider_test.go +++ b/outport/process/alteredaccounts/alteredAccountsProvider_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/trie" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -923,7 +924,7 @@ func testExtractAlteredAccountsFromPoolAddressHasMultipleNfts(t *testing.T) { return nil, false, nil }, } - marshaller := testscommon.MarshalizerMock{} + marshaller := marshallerMock.MarshalizerMock{} args.AccountsDB = &state.AccountsStub{ LoadAccountCalled: func(_ []byte) (vmcommon.AccountHandler, error) { trieMock := trie.DataTrieTrackerStub{ diff --git a/outport/process/executionOrder/transactionsExecutionOrder_test.go b/outport/process/executionOrder/transactionsExecutionOrder_test.go index b2e09e47da2..8edef7ab780 100644 --- a/outport/process/executionOrder/transactionsExecutionOrder_test.go +++ b/outport/process/executionOrder/transactionsExecutionOrder_test.go @@ -15,6 +15,8 @@ import ( processOut "github.com/multiversx/mx-chain-go/outport/process" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -22,9 +24,9 @@ import ( func newArgStorer() ArgSorter { return ArgSorter{ Hasher: testscommon.KeccakMock{}, - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, MbsStorer: testscommon.CreateMemUnit(), - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFrontRunningProtectionFlagEnabledField: true, }, } diff --git a/outport/process/factory/check_test.go b/outport/process/factory/check_test.go index dcd5c3cbbdc..44f4d06e55b 100644 --- a/outport/process/factory/check_test.go +++ b/outport/process/factory/check_test.go @@ -8,7 +8,9 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/transactionsfee" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/stretchr/testify/require" @@ -19,7 +21,7 @@ func createArgOutportDataProviderFactory() ArgOutportDataProviderFactory { HasDrivers: false, AddressConverter: testscommon.NewPubkeyConverterMock(32), AccountsDB: &state.AccountsStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, EsdtDataStorageHandler: &testscommon.EsdtStorageHandlerStub{}, TransactionsStorer: &genericMocks.StorerMock{}, ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, @@ -29,7 +31,7 @@ func createArgOutportDataProviderFactory() ArgOutportDataProviderFactory { EconomicsData: &economicsmocks.EconomicsHandlerMock{}, Hasher: &testscommon.KeccakMock{}, MbsStorer: &genericMocks.StorerMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/outport/process/outportDataProvider_test.go b/outport/process/outportDataProvider_test.go index 3aa79ef774a..667ae905fbe 100644 --- a/outport/process/outportDataProvider_test.go +++ b/outport/process/outportDataProvider_test.go @@ -8,13 +8,14 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/transactionsfee" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/stretchr/testify/require" ) func createArgOutportDataProvider() ArgOutportDataProvider { txsFeeProc, _ := transactionsfee.NewTransactionsFeeProcessor(transactionsfee.ArgTransactionsFeeProcessor{ - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, TransactionsStorer: &genericMocks.StorerMock{}, ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, TxFeeCalculator: &mock.EconomicsHandlerMock{}, @@ -29,7 +30,7 @@ func createArgOutportDataProvider() ArgOutportDataProvider { EconomicsData: &mock.EconomicsHandlerMock{}, ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, ExecutionOrderHandler: &mock.ExecutionOrderHandlerStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, } } diff --git a/outport/process/transactionsfee/transactionsFeeProcessor_test.go b/outport/process/transactionsfee/transactionsFeeProcessor_test.go index 4495b1d0c75..e0efbab8ada 100644 --- a/outport/process/transactionsfee/transactionsFeeProcessor_test.go +++ b/outport/process/transactionsfee/transactionsFeeProcessor_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/outport/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/require" ) @@ -21,7 +22,7 @@ var pubKeyConverter, _ = pubkeyConverter.NewBech32PubkeyConverter(32, "erd") func prepareMockArg() ArgTransactionsFeeProcessor { return ArgTransactionsFeeProcessor{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, TransactionsStorer: genericMocks.NewStorerMock(), ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, TxFeeCalculator: &mock.EconomicsHandlerMock{}, diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index d966615e378..8769e7b664c 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -38,6 +38,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/storage/storageunit" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -1741,7 +1742,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = userAccountsDb.GetAllLeaves(iteratorChannels, context.Background(), rootHash) + err = userAccountsDb.GetAllLeaves(iteratorChannels, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) if err != nil { return err } @@ -1755,7 +1756,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl totalSizeAccountsDataTries := 0 totalSizeCodeLeaves := 0 for leaf := range iteratorChannels.LeavesChan { - userAccount, errUnmarshal := unmarshalUserAccount(leaf.Key(), leaf.Value(), bp.marshalizer) + userAccount, errUnmarshal := bp.unmarshalUserAccount(leaf.Key(), leaf.Value()) if errUnmarshal != nil { numCodeLeaves++ totalSizeCodeLeaves += len(leaf.Value()) @@ -1770,7 +1771,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - errDataTrieGet := userAccountsDb.GetAllLeaves(dataTrie, context.Background(), rh) + errDataTrieGet := userAccountsDb.GetAllLeaves(dataTrie, context.Background(), rh, parsers.NewMainTrieLeafParser()) if errDataTrieGet != nil { continue } @@ -1825,12 +1826,20 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl return nil } -func unmarshalUserAccount(address []byte, userAccountsBytes []byte, marshalizer marshal.Marshalizer) (state.UserAccountHandler, error) { - userAccount, err := state.NewUserAccount(address) +func (bp *baseProcessor) unmarshalUserAccount( + address []byte, + userAccountsBytes []byte, +) (state.UserAccountHandler, error) { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: bp.hasher, + Marshaller: bp.marshalizer, + EnableEpochsHandler: bp.enableEpochsHandler, + } + userAccount, err := state.NewUserAccount(address, argsAccCreation) if err != nil { return nil, err } - err = marshalizer.Unmarshal(userAccount, userAccountsBytes) + err = bp.marshalizer.Unmarshal(userAccount, userAccountsBytes) if err != nil { return nil, err } diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index 3b93fb7a465..75d5774bcfb 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -42,10 +42,12 @@ import ( dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/mainFactoryMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -382,7 +384,7 @@ func createComponentHolderMocks() ( RoundField: &mock.RoundHandlerMock{}, ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } dataComponents := &mock.DataComponentsMock{ @@ -444,7 +446,7 @@ func createMockTransactionCoordinatorArguments( EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1892,7 +1894,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededShouldWork(t *testing.T) { RootHashCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { close(channels.LeavesChan) channels.ErrChan.Close() return nil @@ -1936,7 +1938,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T RootHashCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { close(channels.LeavesChan) channels.ErrChan.Close() return expectedErr @@ -1974,7 +1976,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T RootHashCalled: func() ([]byte, error) { return rootHash, nil }, - GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error { channels.ErrChan.WriteInChanNonBlocking(expectedErr) close(channels.LeavesChan) return nil @@ -2031,7 +2033,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededShouldUseDataTrieIfNeededW arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) arguments.AccountsDB = map[state.AccountsDbIdentifier]state.AccountsAdapter{ state.UserAccountsState: &stateMock.AccountsStub{ - GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rh []byte) error { + GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rh []byte, _ common.TrieLeafParser) error { if bytes.Equal(rootHash, rh) { calledWithUserAccountRootHash = true close(channels.LeavesChan) @@ -2476,7 +2478,7 @@ func TestBaseProcessor_getIndexOfFirstMiniBlockToBeExecuted(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2490,7 +2492,7 @@ func TestBaseProcessor_getIndexOfFirstMiniBlockToBeExecuted(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2534,7 +2536,7 @@ func TestBaseProcessor_getFinalMiniBlocks(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2549,7 +2551,7 @@ func TestBaseProcessor_getFinalMiniBlocks(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2663,11 +2665,11 @@ func TestBaseProcessor_checkScheduledMiniBlockValidity(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } expectedErr := errors.New("expected error") - coreComponents.IntMarsh = &testscommon.MarshalizerStub{ + coreComponents.IntMarsh = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -2697,7 +2699,7 @@ func TestBaseProcessor_checkScheduledMiniBlockValidity(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } coreComponents.Hash = &mock.HasherStub{ @@ -2730,7 +2732,7 @@ func TestBaseProcessor_checkScheduledMiniBlockValidity(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2811,7 +2813,7 @@ func TestBaseProcessor_setMiniBlockHeaderReservedField(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2840,7 +2842,7 @@ func TestBaseProcessor_setMiniBlockHeaderReservedField(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true} arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) arguments.ScheduledTxsExecutionHandler = &testscommon.ScheduledTxsExecutionStub{ @@ -2874,7 +2876,7 @@ func TestBaseProcessor_setMiniBlockHeaderReservedField(t *testing.T) { }, } - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -2904,7 +2906,7 @@ func TestBaseProcessor_setMiniBlockHeaderReservedField(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } shardId := uint32(1) diff --git a/process/block/displayBlock_test.go b/process/block/displayBlock_test.go index ccc6eced1a0..187b87edbe2 100644 --- a/process/block/displayBlock_test.go +++ b/process/block/displayBlock_test.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,7 +37,7 @@ func createMockArgsTransactionCounter() ArgsTransactionCounter { return ArgsTransactionCounter{ AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, Hasher: &testscommon.HasherStub{}, - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, ShardID: 0, } } diff --git a/process/block/export_test.go b/process/block/export_test.go index a382ac21519..85e1985ae08 100644 --- a/process/block/export_test.go +++ b/process/block/export_test.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" @@ -113,7 +114,7 @@ func NewShardProcessorEmptyWith3shards( RoundField: &mock.RoundHandlerMock{}, ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } dataComponents := &mock.DataComponentsMock{ Storage: &storageStubs.ChainStorerStub{}, diff --git a/process/block/metablock_test.go b/process/block/metablock_test.go index 0916ebc80b6..cb3d9ae21e9 100644 --- a/process/block/metablock_test.go +++ b/process/block/metablock_test.go @@ -26,6 +26,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" @@ -54,7 +55,7 @@ func createMockComponentHolders() ( RoundField: &mock.RoundHandlerMock{RoundTimeDuration: time.Second}, ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } dataComponents := &mock.DataComponentsMock{ @@ -3154,7 +3155,7 @@ func TestMetaProcessor_ProcessEpochStartMetaBlock(t *testing.T) { t.Parallel() coreC, dataC, bootstrapC, statusC := createMockComponentHolders() - enableEpochsHandler, _ := coreC.EnableEpochsHandlerField.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := coreC.EnableEpochsHandlerField.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.StakingV2EnableEpochField = 0 arguments := createMockMetaArguments(coreC, dataC, bootstrapC, statusC) @@ -3187,7 +3188,7 @@ func TestMetaProcessor_ProcessEpochStartMetaBlock(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createMockComponentHolders() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ StakingV2EnableEpochField: 10, } arguments := createMockMetaArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -3389,7 +3390,7 @@ func TestMetaProcessor_CreateEpochStartBodyShouldWork(t *testing.T) { t.Parallel() coreC, dataC, bootstrapC, statusC := createMockComponentHolders() - enableEpochsHandler, _ := coreC.EnableEpochsHandlerField.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := coreC.EnableEpochsHandlerField.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.StakingV2EnableEpochField = 0 arguments := createMockMetaArguments(coreC, dataC, bootstrapC, statusC) @@ -3463,7 +3464,7 @@ func TestMetaProcessor_CreateEpochStartBodyShouldWork(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createMockComponentHolders() - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ StakingV2EnableEpochField: 10, } arguments := createMockMetaArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -3542,7 +3543,7 @@ func TestMetaProcessor_getFinalMiniBlockHashes(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createMockComponentHolders() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: false, } coreComponents.EnableEpochsHandlerField = enableEpochsHandlerStub @@ -3560,7 +3561,7 @@ func TestMetaProcessor_getFinalMiniBlockHashes(t *testing.T) { t.Parallel() coreComponents, dataComponents, bootstrapComponents, statusComponents := createMockComponentHolders() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } coreComponents.EnableEpochsHandlerField = enableEpochsHandlerStub diff --git a/process/block/postprocess/intermediateResults_test.go b/process/block/postprocess/intermediateResults_test.go index 67633564fea..1b70fba4130 100644 --- a/process/block/postprocess/intermediateResults_test.go +++ b/process/block/postprocess/intermediateResults_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -39,7 +40,7 @@ func createMockArgsNewIntermediateResultsProcessor() ArgsNewIntermediateResultsP BlockType: block.SmartContractResultBlock, CurrTxs: &mock.TxForCurrentBlockStub{}, EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } return args @@ -631,7 +632,7 @@ func TestIntermediateResultsProcessor_VerifyInterMiniBlocksBodyShouldPass(t *tes return maxGasLimitPerBlock }, } - enableEpochHandler := &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: false} + enableEpochHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: false} args.EnableEpochsHandler = enableEpochHandler irp, err := NewIntermediateResultsProcessor(args) diff --git a/process/block/preprocess/gasComputation_test.go b/process/block/preprocess/gasComputation_test.go index 91c504d97c6..6660b1a92a0 100644 --- a/process/block/preprocess/gasComputation_test.go +++ b/process/block/preprocess/gasComputation_test.go @@ -13,12 +13,13 @@ import ( "github.com/multiversx/mx-chain-go/process/block/preprocess" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func createEnableEpochsHandler() common.EnableEpochsHandler { - return &testscommon.EnableEpochsHandlerStub{ + return &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, } } @@ -446,7 +447,7 @@ func TestComputeGasProvidedByMiniBlock_ShouldWorkV1(t *testing.T) { } return process.MoveBalance, process.MoveBalance }}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := make([][]byte, 0) @@ -526,7 +527,7 @@ func TestComputeGasProvidedByTx_ShouldWorkWhenTxReceiverAddressIsASmartContractC ComputeTransactionTypeCalled: func(tx data.TransactionHandler) (process.TransactionType, process.TransactionType) { return process.SCInvoking, process.SCInvoking }}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) tx := transaction.Transaction{GasLimit: 7, RcvAddr: make([]byte, core.NumInitCharactersForScAddress+1)} diff --git a/process/block/preprocess/gasTracker_test.go b/process/block/preprocess/gasTracker_test.go index a04c92c8d6c..6f75da4aaa3 100644 --- a/process/block/preprocess/gasTracker_test.go +++ b/process/block/preprocess/gasTracker_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/require" ) @@ -98,7 +99,7 @@ func Test_computeGasProvidedSelfSenderMoveBalanceIntra(t *testing.T) { receiverShardID := uint32(0) rcvAddr, _ := hex.DecodeString("addrReceiver" + suffixShard0) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 75000, @@ -133,7 +134,7 @@ func Test_computeGasProvidedSelfSenderSCCallIntra(t *testing.T) { receiverShardID := uint32(0) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard0) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 500000, consumedReceiverShard: 500000, @@ -169,7 +170,7 @@ func Test_computeGasProvidedByTxSelfSenderMoveBalanceCross(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString("addrReceiver" + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 75000, @@ -204,7 +205,7 @@ func Test_computeGasProvidedByTxSelfSenderScCallCross(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 500000, consumedReceiverShard: 500000, @@ -240,7 +241,7 @@ func Test_computeGasProvidedByTxGasHandlerComputeGasErrors(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 500000, consumedReceiverShard: 500000, @@ -283,7 +284,7 @@ func Test_computeGasProvidedByTxGasHandlerRefundGasLargerThanLimit(t *testing.T) receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 500000, consumedReceiverShard: 500000, @@ -329,7 +330,7 @@ func Test_computeGasProvidedWithErrorForGasConsumedForTx(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 1600000000, @@ -371,7 +372,7 @@ func Test_computeGasProvidedMaxGasLimitInSenderShardReached(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 500000, @@ -408,7 +409,7 @@ func Test_computeGasProvidedMaxGasLimitInReceiverShardReached(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 500000, @@ -445,7 +446,7 @@ func Test_computeGasProvidedMaxGasLimitPerBlockReached(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 500000, @@ -482,7 +483,7 @@ func Test_computeGasProvidedOK(t *testing.T) { receiverShardID := uint32(1) rcvAddr, _ := hex.DecodeString(smartContractAddressStart + suffixShard1) hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} gcr := &gasConsumedResult{ consumedSenderShard: 75000, consumedReceiverShard: 500000, diff --git a/process/block/preprocess/miniBlockBuilder_test.go b/process/block/preprocess/miniBlockBuilder_test.go index 37a9ca5f94b..d3a04147864 100644 --- a/process/block/preprocess/miniBlockBuilder_test.go +++ b/process/block/preprocess/miniBlockBuilder_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -906,7 +907,7 @@ func createWrappedTransaction( receiverShardID uint32, ) *txcache.WrappedTransaction { hasher := &hashingMocks.HasherMock{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} txMarshalled, _ := marshaller.Marshal(tx) txHash := hasher.Compute(string(txMarshalled)) diff --git a/process/block/preprocess/scheduledTxsExecution_test.go b/process/block/preprocess/scheduledTxsExecution_test.go index 7fc6c834249..ab1c7c3d537 100644 --- a/process/block/preprocess/scheduledTxsExecution_test.go +++ b/process/block/preprocess/scheduledTxsExecution_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" @@ -23,6 +21,8 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" @@ -966,7 +966,7 @@ func TestScheduledTxsExecution_getScheduledInfoForHeaderShouldFail(t *testing.T) return nil, nil }, }, - &testscommon.MarshalizerStub{ + &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(_ interface{}, _ []byte) error { return expectedErr }, @@ -1008,7 +1008,7 @@ func TestScheduledTxsExecution_getScheduledInfoForHeaderShouldWork(t *testing.T) return marshalledSCRsSavedData, nil }, }, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1051,7 +1051,7 @@ func TestScheduledTxsExecution_getMarshalledScheduledInfoShouldWork(t *testing.T &testscommon.TxProcessorMock{}, &testscommon.TransactionCoordinatorMock{}, genericMocks.NewStorerMock(), - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1120,7 +1120,7 @@ func TestScheduledTxsExecution_RollBackToBlockShouldWork(t *testing.T) { return marshalledSCRsSavedData, nil }, }, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1184,7 +1184,7 @@ func TestScheduledTxsExecution_SaveState(t *testing.T) { return nil }, }, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1213,7 +1213,7 @@ func TestScheduledTxsExecution_SaveStateIfNeeded(t *testing.T) { return nil }, }, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1491,7 +1491,7 @@ func TestScheduledTxsExecution_GetScheduledRootHashForHeaderShouldWork(t *testin return marshalledSCRsSavedData, nil }, }, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, &mock.ShardCoordinatorStub{}, ) @@ -1600,7 +1600,7 @@ func TestScheduledTxsExecution_setScheduledMiniBlockHashes(t *testing.T) { &testscommon.TxProcessorMock{}, &testscommon.TransactionCoordinatorMock{}, genericMocks.NewStorerMock(), - &testscommon.MarshalizerStub{ + &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -1627,7 +1627,7 @@ func TestScheduledTxsExecution_setScheduledMiniBlockHashes(t *testing.T) { &testscommon.TxProcessorMock{}, &testscommon.TransactionCoordinatorMock{}, genericMocks.NewStorerMock(), - &testscommon.MarshalizerStub{ + &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { assert.Equal(t, mb, obj) return nil, nil diff --git a/process/block/preprocess/smartContractResults_test.go b/process/block/preprocess/smartContractResults_test.go index af8b9ffaa14..95382f59958 100644 --- a/process/block/preprocess/smartContractResults_test.go +++ b/process/block/preprocess/smartContractResults_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -70,7 +71,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilPool(t *testing.T createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -97,7 +98,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilStore(t *testing. createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -124,7 +125,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilHasher(t *testing createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -151,7 +152,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilMarsalizer(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -178,7 +179,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilTxProce(t *testin createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -205,7 +206,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilShardCoord(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -232,7 +233,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilAccounts(t *testi createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -258,7 +259,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilRequestFunc(t *te createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -285,7 +286,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilGasHandler(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -312,7 +313,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorShouldWork(t *testin createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -339,7 +340,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilPubkeyConverter(t nil, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -366,7 +367,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilBlockSizeComputat createMockPubkeyConverter(), nil, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -393,7 +394,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilBalanceComputatio createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, nil, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -447,7 +448,7 @@ func TestScrsPreprocessor_NewSmartContractResultPreprocessorNilProcessedMiniBloc createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, nil, ) @@ -474,7 +475,7 @@ func TestScrsPreProcessor_GetTransactionFromPool(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -511,7 +512,7 @@ func TestScrsPreprocessor_RequestTransactionNothingToRequestAsGeneratedAtProcess createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -550,7 +551,7 @@ func TestScrsPreprocessor_RequestTransactionFromNetwork(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -588,7 +589,7 @@ func TestScrsPreprocessor_RequestBlockTransactionFromMiniBlockFromNetwork(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -637,7 +638,7 @@ func TestScrsPreprocessor_ReceivedTransactionShouldEraseRequested(t *testing.T) createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -712,7 +713,7 @@ func TestScrsPreprocessor_GetAllTxsFromMiniBlockShouldWork(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -800,7 +801,7 @@ func TestScrsPreprocessor_GetAllTxsFromMiniBlockShouldWorkEvenIfScrIsMisplaced(t createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -844,7 +845,7 @@ func TestScrsPreprocessor_RemoveBlockDataFromPoolsNilBlockShouldErr(t *testing.T createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -873,7 +874,7 @@ func TestScrsPreprocessor_RemoveBlockDataFromPoolsOK(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -915,7 +916,7 @@ func TestScrsPreprocessor_IsDataPreparedErr(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -944,7 +945,7 @@ func TestScrsPreprocessor_IsDataPrepared(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -978,7 +979,7 @@ func TestScrsPreprocessor_SaveTxsToStorage(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1035,7 +1036,7 @@ func TestScrsPreprocessor_SaveTxsToStorageShouldSaveCorrectly(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1114,7 +1115,7 @@ func TestScrsPreprocessor_SaveTxsToStorageMissingTransactionsShouldNotErr(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1161,7 +1162,7 @@ func TestScrsPreprocessor_ProcessBlockTransactionsShouldWork(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1224,7 +1225,7 @@ func TestScrsPreprocessor_ProcessBlockTransactionsMissingTrieNode(t *testing.T) createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1261,7 +1262,7 @@ func TestScrsPreprocessor_ProcessBlockTransactionsMissingTrieNode(t *testing.T) func TestScrsPreprocessor_ProcessBlockTransactionsShouldErrMaxGasLimitPerBlockInSelfShardIsReached(t *testing.T) { t.Parallel() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} enableEpochsHandler := enableEpochsHandlerStub tdp := initDataPool() requestTransaction := func(shardID uint32, txHashes [][]byte) {} @@ -1364,7 +1365,7 @@ func TestScrsPreprocessor_ProcessMiniBlock(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1408,7 +1409,7 @@ func TestScrsPreprocessor_ProcessMiniBlockWrongTypeMiniblockShouldErr(t *testing createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1471,7 +1472,7 @@ func TestScrsPreprocessor_RestoreBlockDataIntoPools(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1515,7 +1516,7 @@ func TestScrsPreprocessor_RestoreBlockDataIntoPoolsNilMiniblockPoolShouldErr(t * createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1549,7 +1550,7 @@ func TestSmartContractResults_CreateBlockStartedShouldEmptyTxHashAndInfo(t *test createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) @@ -1577,7 +1578,7 @@ func TestSmartContractResults_GetAllCurrentUsedTxs(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, ) diff --git a/process/block/preprocess/transactionsV2_test.go b/process/block/preprocess/transactionsV2_test.go index a2b0326068a..a3a863aa4a9 100644 --- a/process/block/preprocess/transactionsV2_test.go +++ b/process/block/preprocess/transactionsV2_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -62,7 +63,7 @@ func createTransactionPreprocessor() *transactions { return false }, }, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{ ComputeTransactionTypeCalled: func(tx data.TransactionHandler) (process.TransactionType, process.TransactionType) { if bytes.Equal(tx.GetRcvAddr(), []byte("smart contract address")) { diff --git a/process/block/preprocess/transactions_test.go b/process/block/preprocess/transactions_test.go index 872472cd218..42f5ae7527d 100644 --- a/process/block/preprocess/transactions_test.go +++ b/process/block/preprocess/transactions_test.go @@ -31,8 +31,10 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/vm" @@ -230,7 +232,7 @@ func createDefaultTransactionsProcessorArgs() ArgsTransactionPreProcessor { PubkeyConverter: createMockPubkeyConverter(), BlockSizeComputation: &testscommon.BlockSizeComputationStub{}, BalanceComputation: &testscommon.BalanceComputationStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -838,7 +840,7 @@ func TestTransactions_GetTotalGasConsumedShouldWork(t *testing.T) { var gasPenalized uint64 args := createDefaultTransactionsProcessorArgs() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} args.EnableEpochsHandler = enableEpochsHandlerStub args.GasHandler = &mock.GasHandlerMock{ TotalGasProvidedCalled: func() uint64 { @@ -877,7 +879,7 @@ func TestTransactions_UpdateGasConsumedWithGasRefundedAndGasPenalizedShouldWork( var gasPenalized uint64 args := createDefaultTransactionsProcessorArgs() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} args.EnableEpochsHandler = enableEpochsHandlerStub args.GasHandler = &mock.GasHandlerMock{ GasRefundedCalled: func(_ []byte) uint64 { @@ -1077,7 +1079,7 @@ func BenchmarkSortTransactionsByNonceAndSender_WhenReversedNoncesWithFrontRunnin basePreProcess: &basePreProcess{ hasher: hasher, marshalizer: marshaller, - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, } numRands := 1000 @@ -1292,7 +1294,7 @@ func TestTransactionsPreprocessor_ProcessMiniBlockShouldErrMaxGasLimitUsedForDes } args := createDefaultTransactionsProcessorArgs() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} args.EnableEpochsHandler = enableEpochsHandlerStub args.TxDataPool = tdp.Transactions() args.GasHandler = &mock.GasHandlerMock{ @@ -1386,7 +1388,7 @@ func TestTransactionsPreprocessor_SplitMiniBlocksIfNeededShouldWork(t *testing.T txGasLimit := uint64(100) args := createDefaultTransactionsProcessorArgs() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} args.EnableEpochsHandler = enableEpochsHandlerStub args.EconomicsFee = &economicsmocks.EconomicsHandlerStub{ MaxGasLimitPerMiniBlockForSafeCrossShardCalled: func() uint64 { @@ -1709,7 +1711,7 @@ func TestTransactionsPreProcessor_getRemainingGasPerBlock(t *testing.T) { economicsFee: economicsFee, gasHandler: gasHandler, }, - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, } @@ -1822,7 +1824,7 @@ func TestTransactions_AddTransactions(t *testing.T) { args := createDefaultTransactionsProcessorArgs() txs := []data.TransactionHandler{tx1} expectedErr := errors.New("expected error") - args.Marshalizer = &testscommon.MarshalizerStub{ + args.Marshalizer = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -1925,7 +1927,7 @@ func TestTransactions_ComputeCacheIdentifier(t *testing.T) { txs := &transactions{ basePreProcess: &basePreProcess{ - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, } @@ -1943,7 +1945,7 @@ func TestTransactions_ComputeCacheIdentifier(t *testing.T) { gasTracker: gasTracker{ shardCoordinator: coordinator, }, - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, }, }, @@ -2024,7 +2026,7 @@ func TestTransactions_RestoreBlockDataIntoPools(t *testing.T) { assert.Equal(t, 0, len(mbPool.Keys())) }) t.Run("feat scheduled not activated", func(t *testing.T) { - txs.basePreProcess.enableEpochsHandler = &testscommon.EnableEpochsHandlerStub{} + txs.basePreProcess.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{} numRestored, err := txs.RestoreBlockDataIntoPools(body, mbPool) assert.Nil(t, err) @@ -2039,7 +2041,7 @@ func TestTransactions_RestoreBlockDataIntoPools(t *testing.T) { mbPool.Clear() t.Run("feat scheduled activated", func(t *testing.T) { - txs.basePreProcess.enableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + txs.basePreProcess.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } @@ -2158,7 +2160,7 @@ func TestTransactions_getIndexesOfLastTxProcessed(t *testing.T) { t.Parallel() args := createDefaultTransactionsProcessorArgs() - args.Marshalizer = &testscommon.MarshalizerMock{ + args.Marshalizer = &marshallerMock.MarshalizerMock{ Fail: true, } txs, _ := NewTransactionPreprocessor(args) @@ -2168,14 +2170,14 @@ func TestTransactions_getIndexesOfLastTxProcessed(t *testing.T) { pi, err := txs.getIndexesOfLastTxProcessed(miniBlock, headerHandler) assert.Nil(t, pi) - assert.Equal(t, testscommon.ErrMockMarshalizer, err) + assert.Equal(t, marshallerMock.ErrMockMarshalizer, err) }) t.Run("missing mini block header should not get indexes", func(t *testing.T) { t.Parallel() args := createDefaultTransactionsProcessorArgs() - args.Marshalizer = &testscommon.MarshalizerMock{ + args.Marshalizer = &marshallerMock.MarshalizerMock{ Fail: false, } txs, _ := NewTransactionPreprocessor(args) @@ -2192,7 +2194,7 @@ func TestTransactions_getIndexesOfLastTxProcessed(t *testing.T) { t.Parallel() args := createDefaultTransactionsProcessorArgs() - args.Marshalizer = &testscommon.MarshalizerMock{ + args.Marshalizer = &marshallerMock.MarshalizerMock{ Fail: false, } txs, _ := NewTransactionPreprocessor(args) diff --git a/process/block/preprocess/validatorInfoPreProcessor_test.go b/process/block/preprocess/validatorInfoPreProcessor_test.go index 5c016878def..a3e9ac4a410 100644 --- a/process/block/preprocess/validatorInfoPreProcessor_test.go +++ b/process/block/preprocess/validatorInfoPreProcessor_test.go @@ -12,8 +12,10 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,11 +27,11 @@ func TestNewValidatorInfoPreprocessor_NilHasherShouldErr(t *testing.T) { tdp := initDataPool() rtp, err := NewValidatorInfoPreprocessor( nil, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, rtp) @@ -46,7 +48,7 @@ func TestNewValidatorInfoPreprocessor_NilMarshalizerShouldErr(t *testing.T) { &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, rtp) @@ -59,11 +61,11 @@ func TestNewValidatorInfoPreprocessor_NilBlockSizeComputationHandlerShouldErr(t tdp := initDataPool() rtp, err := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, nil, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, rtp) @@ -75,11 +77,11 @@ func TestNewValidatorInfoPreprocessor_NilValidatorInfoPoolShouldErr(t *testing.T rtp, err := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, nil, genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, rtp) @@ -92,11 +94,11 @@ func TestNewValidatorInfoPreprocessor_NilStoreShouldErr(t *testing.T) { tdp := initDataPool() rtp, err := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), nil, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, rtp) @@ -109,7 +111,7 @@ func TestNewValidatorInfoPreprocessor_NilEnableEpochHandlerShouldErr(t *testing. tdp := initDataPool() rtp, err := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), @@ -126,11 +128,11 @@ func TestNewValidatorInfoPreprocessor_OkValsShouldWork(t *testing.T) { tdp := initDataPool() rtp, err := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, err) assert.NotNil(t, rtp) @@ -142,11 +144,11 @@ func TestNewValidatorInfoPreprocessor_CreateMarshalizedDataShouldWork(t *testing tdp := initDataPool() rtp, _ := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) hash := make([][]byte, 0) @@ -162,11 +164,11 @@ func TestNewValidatorInfoPreprocessor_ProcessMiniBlockInvalidMiniBlockTypeShould tdp := initDataPool() rtp, _ := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := make([][]byte, 0) @@ -191,11 +193,11 @@ func TestNewValidatorInfoPreprocessor_ProcessMiniBlockShouldWork(t *testing.T) { tdp := initDataPool() rtp, _ := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := make([][]byte, 0) @@ -220,11 +222,11 @@ func TestNewValidatorInfoPreprocessor_ProcessMiniBlockNotFromMeta(t *testing.T) tdp := initDataPool() rtp, _ := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := make([][]byte, 0) @@ -247,7 +249,7 @@ func TestNewValidatorInfoPreprocessor_RestorePeerBlockIntoPools(t *testing.T) { t.Parallel() hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} tdp := initDataPool() @@ -257,7 +259,7 @@ func TestNewValidatorInfoPreprocessor_RestorePeerBlockIntoPools(t *testing.T) { blockSizeComputation, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := [][]byte{[]byte("tx_hash1")} @@ -292,7 +294,7 @@ func TestNewValidatorInfoPreprocessor_RestoreOtherBlockTypeIntoPoolsShouldNotRes t.Parallel() hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} tdp := initDataPool() @@ -302,7 +304,7 @@ func TestNewValidatorInfoPreprocessor_RestoreOtherBlockTypeIntoPoolsShouldNotRes blockSizeComputation, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := [][]byte{[]byte("tx_hash1")} @@ -337,7 +339,7 @@ func TestNewValidatorInfoPreprocessor_RemovePeerBlockFromPool(t *testing.T) { t.Parallel() hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} tdp := initDataPool() @@ -347,7 +349,7 @@ func TestNewValidatorInfoPreprocessor_RemovePeerBlockFromPool(t *testing.T) { blockSizeComputation, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := [][]byte{[]byte("tx_hash1")} @@ -382,7 +384,7 @@ func TestNewValidatorInfoPreprocessor_RemoveOtherBlockTypeFromPoolShouldNotRemov t.Parallel() hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} tdp := initDataPool() @@ -392,7 +394,7 @@ func TestNewValidatorInfoPreprocessor_RemoveOtherBlockTypeFromPoolShouldNotRemov blockSizeComputation, tdp.ValidatorsInfo(), genericMocks.NewChainStorerMock(0), - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) txHashes := [][]byte{[]byte("tx_hash1")} @@ -431,7 +433,7 @@ func TestNewValidatorInfoPreprocessor_RestoreValidatorsInfo(t *testing.T) { expectedErr := errors.New("error") hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} storer := &storage.ChainStorerStub{ GetAllCalled: func(unitType dataRetriever.UnitType, keys [][]byte) (map[string][]byte, error) { @@ -445,7 +447,7 @@ func TestNewValidatorInfoPreprocessor_RestoreValidatorsInfo(t *testing.T) { blockSizeComputation, tdp.ValidatorsInfo(), storer, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) miniBlock := &block.MiniBlock{} @@ -457,7 +459,7 @@ func TestNewValidatorInfoPreprocessor_RestoreValidatorsInfo(t *testing.T) { t.Parallel() hasher := &hashingMocks.HasherMock{} - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} blockSizeComputation := &testscommon.BlockSizeComputationStub{} shardValidatorInfoHash := []byte("hash") shardValidatorInfo := &state.ShardValidatorInfo{ @@ -488,7 +490,7 @@ func TestNewValidatorInfoPreprocessor_RestoreValidatorsInfo(t *testing.T) { blockSizeComputation, tdp.ValidatorsInfo(), storer, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) miniBlock := &block.MiniBlock{} @@ -538,11 +540,11 @@ func TestValidatorInfoPreprocessor_SaveTxsToStorageShouldWork(t *testing.T) { vip, _ := NewValidatorInfoPreprocessor( &hashingMocks.HasherMock{}, - &testscommon.MarshalizerMock{}, + &marshallerMock.MarshalizerMock{}, &testscommon.BlockSizeComputationStub{}, tdp.ValidatorsInfo(), storer, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) err := vip.SaveTxsToStorage(nil) diff --git a/process/block/shardblock_test.go b/process/block/shardblock_test.go index 5d0b6364a3a..ef9f42bab4e 100644 --- a/process/block/shardblock_test.go +++ b/process/block/shardblock_test.go @@ -37,6 +37,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/outport" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -470,7 +471,7 @@ func TestShardProcessor_ProcessBlockWithInvalidTransactionShouldErr(t *testing.T &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -691,7 +692,7 @@ func TestShardProcessor_ProcessBlockWithErrOnProcessBlockTransactionsCallShouldR &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2594,7 +2595,7 @@ func TestShardProcessor_MarshalizedDataToBroadcastShouldWork(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2702,7 +2703,7 @@ func TestShardProcessor_MarshalizedDataMarshalWithoutSuccess(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3094,7 +3095,7 @@ func TestShardProcessor_CreateMiniBlocksShouldWorkWithIntraShardTxs(t *testing.T &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3275,7 +3276,7 @@ func TestShardProcessor_RestoreBlockIntoPoolsShouldWork(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -5037,7 +5038,7 @@ func TestShardProcessor_createMiniBlocks(t *testing.T) { tx2 := &transaction.Transaction{Nonce: 1} txs := []data.TransactionHandler{tx1, tx2} - coreComponents.EnableEpochsHandlerField = &testscommon.EnableEpochsHandlerStub{ + coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsScheduledMiniBlocksFlagEnabledField: true, } arguments := CreateMockArgumentsMultiShard(coreComponents, dataComponents, boostrapComponents, statusComponents) diff --git a/process/coordinator/printDoubleTransactionsDetector_test.go b/process/coordinator/printDoubleTransactionsDetector_test.go index e7c5136ade4..0ae2915b872 100644 --- a/process/coordinator/printDoubleTransactionsDetector_test.go +++ b/process/coordinator/printDoubleTransactionsDetector_test.go @@ -7,14 +7,16 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" ) func createMockArgsPrintDoubleTransactionsDetector() ArgsPrintDoubleTransactionsDetector { return ArgsPrintDoubleTransactionsDetector{ - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } @@ -131,7 +133,7 @@ func TestPrintDoubleTransactionsDetector_ProcessBlockBody(t *testing.T) { debugCalled := false args := createMockArgsPrintDoubleTransactionsDetector() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsAddFailedRelayedTxToInvalidMBsFlagField: true, } detector, _ := NewPrintDoubleTransactionsDetector(args) diff --git a/process/coordinator/process_test.go b/process/coordinator/process_test.go index 87017fcf030..2981564f5b2 100644 --- a/process/coordinator/process_test.go +++ b/process/coordinator/process_test.go @@ -33,7 +33,9 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -236,7 +238,7 @@ func createMockTransactionCoordinatorArguments() ArgTransactionCoordinator { EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -537,7 +539,7 @@ func createPreProcessorContainer() process.PreProcessorsContainer { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -556,7 +558,7 @@ func createInterimProcessorContainer() process.IntermediateProcessorContainer { Store: initStore(), PoolsHolder: initDataPool([]byte("test_hash1")), EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } preFactory, _ := shard.NewIntermediateProcessorsContainerFactory(argsFactory) container, _ := preFactory.Create() @@ -634,7 +636,7 @@ func createPreProcessorContainerWithDataPool( &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -902,7 +904,7 @@ func TestTransactionCoordinator_CreateMbsAndProcessCrossShardTransactions(t *tes &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1086,7 +1088,7 @@ func TestTransactionCoordinator_CreateMbsAndProcessCrossShardTransactionsNilPreP &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1194,7 +1196,7 @@ func TestTransactionCoordinator_CreateMbsAndProcessTransactionsFromMeNothingToPr &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1731,7 +1733,7 @@ func TestTransactionCoordinator_ProcessBlockTransactionProcessTxError(t *testing &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1857,7 +1859,7 @@ func TestTransactionCoordinator_RequestMiniblocks(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -1996,7 +1998,7 @@ func TestShardProcessor_ProcessMiniBlockCompleteWithOkTxsShouldExecuteThemAndNot &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2136,7 +2138,7 @@ func TestShardProcessor_ProcessMiniBlockCompleteWithErrorWhileProcessShouldCallR &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2191,7 +2193,7 @@ func TestTransactionCoordinator_VerifyCreatedBlockTransactionsNilOrMiss(t *testi Store: &storageStubs.ChainStorerStub{}, PoolsHolder: tdp, EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } preFactory, _ := shard.NewIntermediateProcessorsContainerFactory(argsFactory) container, _ := preFactory.Create() @@ -2251,7 +2253,7 @@ func TestTransactionCoordinator_VerifyCreatedBlockTransactionsOk(t *testing.T) { return MaxGasLimitPerBlock }, }, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } interFactory, _ := shard.NewIntermediateProcessorsContainerFactory(argsFactory) container, _ := interFactory.Create() @@ -2569,7 +2571,7 @@ func TestTransactionCoordinator_VerifyCreatedMiniBlocksShouldReturnWhenEpochIsNo EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ BlockGasAndFeesReCheckEnableEpochField: 1, }, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, @@ -2618,7 +2620,7 @@ func TestTransactionCoordinator_VerifyCreatedMiniBlocksShouldErrMaxGasLimitPerMi }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2689,7 +2691,7 @@ func TestTransactionCoordinator_VerifyCreatedMiniBlocksShouldErrMaxAccumulatedFe }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2765,7 +2767,7 @@ func TestTransactionCoordinator_VerifyCreatedMiniBlocksShouldErrMaxDeveloperFees }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2841,7 +2843,7 @@ func TestTransactionCoordinator_VerifyCreatedMiniBlocksShouldWork(t *testing.T) }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2900,7 +2902,7 @@ func TestTransactionCoordinator_GetAllTransactionsShouldWork(t *testing.T) { EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -2982,7 +2984,7 @@ func TestTransactionCoordinator_VerifyGasLimitShouldErrMaxGasLimitPerMiniBlockIn }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3074,7 +3076,7 @@ func TestTransactionCoordinator_VerifyGasLimitShouldWork(t *testing.T) { }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3152,7 +3154,7 @@ func TestTransactionCoordinator_CheckGasProvidedByMiniBlockInReceiverShardShould EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3201,7 +3203,7 @@ func TestTransactionCoordinator_CheckGasProvidedByMiniBlockInReceiverShardShould }, }, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3257,7 +3259,7 @@ func TestTransactionCoordinator_CheckGasProvidedByMiniBlockInReceiverShardShould }, }, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3320,7 +3322,7 @@ func TestTransactionCoordinator_CheckGasProvidedByMiniBlockInReceiverShardShould }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3386,7 +3388,7 @@ func TestTransactionCoordinator_CheckGasProvidedByMiniBlockInReceiverShardShould }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3439,7 +3441,7 @@ func TestTransactionCoordinator_VerifyFeesShouldErrMissingTransaction(t *testing EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3497,7 +3499,7 @@ func TestTransactionCoordinator_VerifyFeesShouldErrMaxAccumulatedFeesExceeded(t }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3565,7 +3567,7 @@ func TestTransactionCoordinator_VerifyFeesShouldErrMaxDeveloperFeesExceeded(t *t }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3611,7 +3613,7 @@ func TestTransactionCoordinator_VerifyFeesShouldErrMaxAccumulatedFeesExceededWhe tx1GasLimit := uint64(100) - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} dataPool := initDataPool(txHash) txCoordinatorArgs := ArgTransactionCoordinator{ @@ -3696,7 +3698,7 @@ func TestTransactionCoordinator_VerifyFeesShouldErrMaxDeveloperFeesExceededWhenS tx1GasLimit := uint64(100) - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} dataPool := initDataPool(txHash) txCoordinatorArgs := ArgTransactionCoordinator{ Hasher: &hashingMocks.HasherMock{}, @@ -3780,7 +3782,7 @@ func TestTransactionCoordinator_VerifyFeesShouldWork(t *testing.T) { tx1GasLimit := uint64(100) - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} dataPool := initDataPool(txHash) txCoordinatorArgs := ArgTransactionCoordinator{ Hasher: &hashingMocks.HasherMock{}, @@ -3885,7 +3887,7 @@ func TestTransactionCoordinator_GetMaxAccumulatedAndDeveloperFeesShouldErr(t *te EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -3940,7 +3942,7 @@ func TestTransactionCoordinator_GetMaxAccumulatedAndDeveloperFeesShouldWork(t *t }, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -4009,7 +4011,7 @@ func TestTransactionCoordinator_RevertIfNeededShouldWork(t *testing.T) { EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, TransactionsLogProcessor: &mock.TxLogsProcessorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, DoubleTransactionsDetector: &testscommon.PanicDoubleTransactionsDetector{}, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -4078,7 +4080,7 @@ func TestTransactionCoordinator_getFinalCrossMiniBlockInfos(t *testing.T) { t.Parallel() args := createMockTransactionCoordinatorArguments() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} args.EnableEpochsHandler = enableEpochsHandlerStub tc, _ := NewTransactionCoordinator(args) enableEpochsHandlerStub.IsScheduledMiniBlocksFlagEnabledField = true @@ -4353,7 +4355,7 @@ func TestTransactionCoordinator_getIndexesOfLastTxProcessed(t *testing.T) { t.Parallel() args := createMockTransactionCoordinatorArguments() - args.Marshalizer = &testscommon.MarshalizerMock{ + args.Marshalizer = &marshallerMock.MarshalizerMock{ Fail: true, } tc, _ := NewTransactionCoordinator(args) @@ -4363,14 +4365,14 @@ func TestTransactionCoordinator_getIndexesOfLastTxProcessed(t *testing.T) { pi, err := tc.getIndexesOfLastTxProcessed(miniBlock, miniBlockHeader) assert.Nil(t, pi) - assert.Equal(t, testscommon.ErrMockMarshalizer, err) + assert.Equal(t, marshallerMock.ErrMockMarshalizer, err) }) t.Run("should get indexes", func(t *testing.T) { t.Parallel() args := createMockTransactionCoordinatorArguments() - args.Marshalizer = &testscommon.MarshalizerMock{ + args.Marshalizer = &marshallerMock.MarshalizerMock{ Fail: false, } tc, _ := NewTransactionCoordinator(args) diff --git a/process/coordinator/transactionType_test.go b/process/coordinator/transactionType_test.go index 2c66a6af68b..b1e6450a041 100644 --- a/process/coordinator/transactionType_test.go +++ b/process/coordinator/transactionType_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" "github.com/multiversx/mx-chain-vm-common-go/parsers" @@ -27,7 +28,7 @@ func createMockArguments() ArgNewTxTypeHandler { BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } diff --git a/process/dataValidators/txValidator_test.go b/process/dataValidators/txValidator_test.go index 7037932bb02..2640d0acec0 100644 --- a/process/dataValidators/txValidator_test.go +++ b/process/dataValidators/txValidator_test.go @@ -14,6 +14,9 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" @@ -23,7 +26,12 @@ import ( func getAccAdapter(nonce uint64, balance *big.Int) *stateMock.AccountsStub { accDB := &stateMock.AccountsStub{} accDB.GetExistingAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { - acc, _ := state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount(address, argsAccCreation) acc.Nonce = nonce acc.Balance = balance diff --git a/process/economics/economicsData_test.go b/process/economics/economicsData_test.go index 0ac846787c1..c97b41a984f 100644 --- a/process/economics/economicsData_test.go +++ b/process/economics/economicsData_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/smartContract" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" "github.com/stretchr/testify/assert" @@ -98,7 +99,7 @@ func createArgsForEconomicsData(gasModifier float64) economics.ArgsNewEconomicsD args := economics.ArgsNewEconomicsData{ Economics: createDummyEconomicsConfig(feeSettings), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsGasPriceModifierFlagEnabledField: true, }, BuiltInFunctionsCostHandler: &mock.BuiltInCostHandlerStub{}, @@ -112,7 +113,7 @@ func createArgsForEconomicsDataRealFees(handler economics.BuiltInFunctionsCostHa args := economics.ArgsNewEconomicsData{ Economics: createDummyEconomicsConfig(feeSettings), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsGasPriceModifierFlagEnabledField: true, }, BuiltInFunctionsCostHandler: handler, diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go index 17213c3e7b1..afc6de41014 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" @@ -686,7 +687,7 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone EpochNotifierField: &epochNotifier.EpochNotifierStub{}, TxVersionCheckField: versioning.NewTxVersionChecker(1), HardforkTriggerPubKeyField: providedHardforkPubKey, - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } multiSigner := cryptoMocks.NewMultiSigner() cryptoComponents := &mock.CryptoComponentsMock{ diff --git a/process/factory/metachain/intermediateProcessorsContainerFactory_test.go b/process/factory/metachain/intermediateProcessorsContainerFactory_test.go index 327ac2b6812..03c4323c140 100644 --- a/process/factory/metachain/intermediateProcessorsContainerFactory_test.go +++ b/process/factory/metachain/intermediateProcessorsContainerFactory_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -27,7 +28,7 @@ func createMockArgsNewIntermediateProcessorsFactory() metachain.ArgsNewIntermedi Store: &storageStubs.ChainStorerStub{}, PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } return args } diff --git a/process/factory/metachain/preProcessorsContainerFactory_test.go b/process/factory/metachain/preProcessorsContainerFactory_test.go index 9f504b1a227..e7c70ac375b 100644 --- a/process/factory/metachain/preProcessorsContainerFactory_test.go +++ b/process/factory/metachain/preProcessorsContainerFactory_test.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -35,7 +36,7 @@ func TestNewPreProcessorsContainerFactory_NilShardCoordinator(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -64,7 +65,7 @@ func TestNewPreProcessorsContainerFactory_NilStore(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -93,7 +94,7 @@ func TestNewPreProcessorsContainerFactory_NilMarshalizer(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -122,7 +123,7 @@ func TestNewPreProcessorsContainerFactory_NilHasher(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -151,7 +152,7 @@ func TestNewPreProcessorsContainerFactory_NilDataPool(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -180,7 +181,7 @@ func TestNewPreProcessorsContainerFactory_NilAccounts(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -209,7 +210,7 @@ func TestNewPreProcessorsContainerFactory_NilFeeHandler(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -238,7 +239,7 @@ func TestNewPreProcessorsContainerFactory_NilTxProcessor(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -267,7 +268,7 @@ func TestNewPreProcessorsContainerFactory_NilRequestHandler(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -295,7 +296,7 @@ func TestNewPreProcessorsContainerFactory_NilGasHandler(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -323,7 +324,7 @@ func TestNewPreProcessorsContainerFactory_NilBlockTracker(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -351,7 +352,7 @@ func TestNewPreProcessorsContainerFactory_NilPubkeyConverter(t *testing.T) { nil, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -379,7 +380,7 @@ func TestNewPreProcessorsContainerFactory_NilBlockSizeComputationHandler(t *test createMockPubkeyConverter(), nil, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -407,7 +408,7 @@ func TestNewPreProcessorsContainerFactory_NilBalanceComputationHandler(t *testin createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, nil, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -463,7 +464,7 @@ func TestNewPreProcessorsContainerFactory_NilTxTypeHandler(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, nil, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -491,7 +492,7 @@ func TestNewPreProcessorsContainerFactory_NilScheduledTxsExecutionHandler(t *tes createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, nil, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -519,7 +520,7 @@ func TestNewPreProcessorsContainerFactory_NilProcessedMiniBlocksTracker(t *testi createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, nil, @@ -547,7 +548,7 @@ func TestNewPreProcessorsContainerFactory(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -582,7 +583,7 @@ func TestPreProcessorsContainerFactory_CreateErrTxPreproc(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -615,7 +616,7 @@ func TestPreProcessorsContainerFactory_Create(t *testing.T) { createMockPubkeyConverter(), &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, diff --git a/process/factory/metachain/vmContainerFactory_test.go b/process/factory/metachain/vmContainerFactory_test.go index 116c23d225f..679cf8d67a3 100644 --- a/process/factory/metachain/vmContainerFactory_test.go +++ b/process/factory/metachain/vmContainerFactory_test.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -66,7 +67,7 @@ func createVmContainerMockArgument(gasSchedule core.GasScheduleNotifier) ArgsNew ValidatorAccountsDB: &stateMock.AccountsStub{}, ChanceComputer: &mock.RaterMock{}, ShardCoordinator: &mock.ShardCoordinatorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsStakeFlagEnabledField: true, }, } @@ -285,7 +286,7 @@ func TestVmContainerFactory_Create(t *testing.T) { }, }, EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, BuiltInFunctionsCostHandler: &mock.BuiltInCostHandlerStub{}, TxVersionChecker: &testscommon.TxVersionCheckerStub{}, } @@ -344,7 +345,7 @@ func TestVmContainerFactory_Create(t *testing.T) { ValidatorAccountsDB: &stateMock.AccountsStub{}, ChanceComputer: &mock.RaterMock{}, ShardCoordinator: mock.NewMultiShardsCoordinatorMock(1), - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } vmf, err := NewVMContainerFactory(argsNewVMContainerFactory) assert.NotNil(t, vmf) diff --git a/process/factory/shard/intermediateProcessorsContainerFactory_test.go b/process/factory/shard/intermediateProcessorsContainerFactory_test.go index a3aae67c19c..a2e9ecec971 100644 --- a/process/factory/shard/intermediateProcessorsContainerFactory_test.go +++ b/process/factory/shard/intermediateProcessorsContainerFactory_test.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" @@ -61,7 +62,7 @@ func createMockArgsNewIntermediateProcessorsFactory() shard.ArgsNewIntermediateP Store: &storageStubs.ChainStorerStub{}, PoolsHolder: createDataPools(), EconomicsFee: &economicsmocks.EconomicsHandlerStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{IsKeepExecOrderOnCreatedSCRsEnabledField: true}, } return args } diff --git a/process/factory/shard/preProcessorsContainerFactory_test.go b/process/factory/shard/preProcessorsContainerFactory_test.go index 5eec22fc5d2..a75a58d5fe2 100644 --- a/process/factory/shard/preProcessorsContainerFactory_test.go +++ b/process/factory/shard/preProcessorsContainerFactory_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -40,7 +41,7 @@ func TestNewPreProcessorsContainerFactory_NilShardCoordinator(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -71,7 +72,7 @@ func TestNewPreProcessorsContainerFactory_NilStore(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -102,7 +103,7 @@ func TestNewPreProcessorsContainerFactory_NilMarshalizer(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -133,7 +134,7 @@ func TestNewPreProcessorsContainerFactory_NilHasher(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -164,7 +165,7 @@ func TestNewPreProcessorsContainerFactory_NilDataPool(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -195,7 +196,7 @@ func TestNewPreProcessorsContainerFactory_NilAddrConv(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -226,7 +227,7 @@ func TestNewPreProcessorsContainerFactory_NilAccounts(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -257,7 +258,7 @@ func TestNewPreProcessorsContainerFactory_NilTxProcessor(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -288,7 +289,7 @@ func TestNewPreProcessorsContainerFactory_NilSCProcessor(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -319,7 +320,7 @@ func TestNewPreProcessorsContainerFactory_NilSCR(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -350,7 +351,7 @@ func TestNewPreProcessorsContainerFactory_NilRewardTxProcessor(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -381,7 +382,7 @@ func TestNewPreProcessorsContainerFactory_NilRequestHandler(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -412,7 +413,7 @@ func TestNewPreProcessorsContainerFactory_NilFeeHandler(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -443,7 +444,7 @@ func TestNewPreProcessorsContainerFactory_NilGasHandler(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -474,7 +475,7 @@ func TestNewPreProcessorsContainerFactory_NilBlockTracker(t *testing.T) { nil, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -505,7 +506,7 @@ func TestNewPreProcessorsContainerFactory_NilBlockSizeComputationHandler(t *test &mock.BlockTrackerMock{}, nil, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -536,7 +537,7 @@ func TestNewPreProcessorsContainerFactory_NilBalanceComputationHandler(t *testin &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, nil, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -598,7 +599,7 @@ func TestNewPreProcessorsContainerFactory_NilTxTypeHandler(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, nil, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -629,7 +630,7 @@ func TestNewPreProcessorsContainerFactory_NilScheduledTxsExecutionHandler(t *tes &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, nil, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -660,7 +661,7 @@ func TestNewPreProcessorsContainerFactory_NilProcessedMiniBlocksTracker(t *testi &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, nil, @@ -691,7 +692,7 @@ func TestNewPreProcessorsContainerFactory(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -727,7 +728,7 @@ func TestPreProcessorsContainerFactory_CreateErrTxPreproc(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -769,7 +770,7 @@ func TestPreProcessorsContainerFactory_CreateErrScrPreproc(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, @@ -814,7 +815,7 @@ func TestPreProcessorsContainerFactory_Create(t *testing.T) { &mock.BlockTrackerMock{}, &testscommon.BlockSizeComputationStub{}, &testscommon.BalanceComputationStub{}, - &testscommon.EnableEpochsHandlerStub{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &testscommon.TxTypeHandlerMock{}, &testscommon.ScheduledTxsExecutionStub{}, &testscommon.ProcessedMiniBlocksTrackerStub{}, diff --git a/process/factory/shard/vmContainerFactory_test.go b/process/factory/shard/vmContainerFactory_test.go index b27914ad1af..afab74e4399 100644 --- a/process/factory/shard/vmContainerFactory_test.go +++ b/process/factory/shard/vmContainerFactory_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" @@ -40,7 +41,7 @@ func createMockVMAccountsArguments() ArgVMContainerFactory { BlockGasLimit: 10000, GasSchedule: testscommon.NewGasScheduleNotifierMock(wasmConfig.MakeGasMapForTests()), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, WasmVMChangeLocker: &sync.RWMutex{}, ESDTTransferParser: esdtTransferParser, BuiltInFunctions: vmcommonBuiltInFunctions.NewBuiltInFunctionContainer(), diff --git a/process/gasCost.go b/process/gasCost.go index 25b0dfe2881..41d6048ccff 100644 --- a/process/gasCost.go +++ b/process/gasCost.go @@ -31,6 +31,8 @@ type BuiltInCost struct { SetGuardian uint64 GuardAccount uint64 UnGuardAccount uint64 + TrieLoadPerNode uint64 + TrieStorePerNode uint64 } // GasCost holds all the needed gas costs for system smart contracts diff --git a/process/guardian/guardedAccount_test.go b/process/guardian/guardedAccount_test.go index d6550babd55..65c80532d28 100644 --- a/process/guardian/guardedAccount_test.go +++ b/process/guardian/guardedAccount_test.go @@ -11,8 +11,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/guardians" "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMocks "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/testscommon/vmcommonMocks" @@ -21,7 +21,7 @@ import ( ) func TestNewGuardedAccount(t *testing.T) { - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} en := &epochNotifier.EpochNotifierStub{} ga, err := NewGuardedAccount(marshaller, en, 10) require.Nil(t, err) @@ -138,7 +138,7 @@ func TestGuardedAccount_getConfiguredGuardians(t *testing.T) { expectedErr := errors.New("expected error") ga := createGuardedAccountWithEpoch(10) - ga.marshaller = &testscommon.MarshalizerStub{ + ga.marshaller = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return expectedErr }, @@ -188,7 +188,7 @@ func TestGuardedAccount_saveAccountGuardians(t *testing.T) { expectedErr := errors.New("expected error") ga := createGuardedAccountWithEpoch(10) - ga.marshaller = &testscommon.MarshalizerStub{ + ga.marshaller = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr }, @@ -1160,14 +1160,14 @@ func TestGuardedAccount_IsInterfaceNil(t *testing.T) { var ga *guardedAccount require.True(t, check.IfNil(ga)) - ga, _ = NewGuardedAccount(&testscommon.MarshalizerMock{}, &epochNotifier.EpochNotifierStub{}, 10) + ga, _ = NewGuardedAccount(&marshallerMock.MarshalizerMock{}, &epochNotifier.EpochNotifierStub{}, 10) require.False(t, check.IfNil(ga)) } func TestGuardedAccount_EpochConcurrency(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} currentEpoch := uint32(0) en := forking.NewGenericEpochNotifier() ga, _ := NewGuardedAccount(marshaller, en, 2) @@ -1206,7 +1206,7 @@ func TestGuardedAccount_EpochConcurrency(t *testing.T) { } func createGuardedAccountWithEpoch(epoch uint32) *guardedAccount { - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} en := &epochNotifier.EpochNotifierStub{ RegisterNotifyHandlerCalled: func(handler vmcommon.EpochSubscriberHandler) { handler.EpochConfirmed(epoch, 0) diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go index d77a5ff5ea9..0912de698c1 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -72,7 +73,7 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone TxVersionCheckField: versioning.NewTxVersionChecker(1), EpochNotifierField: &epochNotifier.EpochNotifierStub{}, HardforkTriggerPubKeyField: []byte("provided hardfork pub key"), - EnableEpochsHandlerField: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } cryptoComponents := &mock.CryptoComponentsMock{ BlockSig: createMockSigner(), diff --git a/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go b/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go index 749d2fe38df..a46f327c4f3 100644 --- a/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go +++ b/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go @@ -7,7 +7,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,7 +21,7 @@ func createMockValidatorInfoBuff() []byte { Rating: 10, } - marshalizerMock := testscommon.MarshalizerMock{} + marshalizerMock := marshallerMock.MarshalizerMock{} buff, _ := marshalizerMock.Marshal(vi) return buff diff --git a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go index c540e2a002c..38a56751f05 100644 --- a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go +++ b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -26,7 +27,7 @@ func createPeerAuthenticationInterceptorProcessArg() processor.ArgPeerAuthentica return processor.ArgPeerAuthenticationInterceptorProcessor{ PeerAuthenticationCacher: testscommon.NewCacherStub(), PeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, HardforkTrigger: &testscommon.HardforkTriggerStub{}, } } @@ -149,7 +150,7 @@ func TestPeerAuthenticationInterceptorProcessor_Save(t *testing.T) { expectedError := errors.New("expected error") args := createPeerAuthenticationInterceptorProcessArg() - args.Marshaller = &testscommon.MarshalizerStub{ + args.Marshaller = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return expectedError }, diff --git a/process/interceptors/processor/validatorInfoInterceptorProcessor_test.go b/process/interceptors/processor/validatorInfoInterceptorProcessor_test.go index c21c9c9291c..d4b56cdc430 100644 --- a/process/interceptors/processor/validatorInfoInterceptorProcessor_test.go +++ b/process/interceptors/processor/validatorInfoInterceptorProcessor_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,7 +30,7 @@ func createMockValidatorInfo() state.ValidatorInfo { func createMockInterceptedValidatorInfo() process.InterceptedData { args := peer.ArgInterceptedValidatorInfo{ - Marshalizer: testscommon.MarshalizerMock{}, + Marshalizer: marshallerMock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, } args.DataBuff, _ = args.Marshalizer.Marshal(createMockValidatorInfo()) @@ -102,7 +103,7 @@ func TestValidatorInfoInterceptorProcessor_Save(t *testing.T) { providedData := createMockInterceptedValidatorInfo() wasHasOrAddCalled := false args := createMockArgValidatorInfoInterceptorProcessor() - providedBuff, _ := testscommon.MarshalizerMock{}.Marshal(createMockValidatorInfo()) + providedBuff, _ := marshallerMock.MarshalizerMock{}.Marshal(createMockValidatorInfo()) hasher := hashingMocks.HasherMock{} providedHash := hasher.Compute(string(providedBuff)) diff --git a/process/peer/interceptedValidatorInfo_test.go b/process/peer/interceptedValidatorInfo_test.go index c71521e3f53..f4bc60fe30a 100644 --- a/process/peer/interceptedValidatorInfo_test.go +++ b/process/peer/interceptedValidatorInfo_test.go @@ -8,8 +8,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +17,7 @@ import ( func createMockArgInterceptedValidatorInfo() ArgInterceptedValidatorInfo { args := ArgInterceptedValidatorInfo{ - Marshalizer: testscommon.MarshalizerMock{}, + Marshalizer: marshallerMock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, } args.DataBuff, _ = args.Marshalizer.Marshal(createMockShardValidatorInfo()) @@ -63,7 +63,7 @@ func TestNewInterceptedValidatorInfo(t *testing.T) { expectedErr := errors.New("expected err") args := createMockArgInterceptedValidatorInfo() - args.Marshalizer = &testscommon.MarshalizerStub{ + args.Marshalizer = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { return expectedErr }, diff --git a/process/peer/process.go b/process/peer/process.go index f7d15ed7917..08234eadfc5 100644 --- a/process/peer/process.go +++ b/process/peer/process.go @@ -21,6 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -568,7 +569,7 @@ func (vs *validatorStatistics) GetValidatorInfoForRootHash(rootHash []byte) (map LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := vs.peerAdapter.GetAllLeaves(leavesChannels, context.Background(), rootHash) + err := vs.peerAdapter.GetAllLeaves(leavesChannels, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) if err != nil { return nil, err } diff --git a/process/peer/process_test.go b/process/peer/process_test.go index 383d17d51e1..e71684c4c19 100644 --- a/process/peer/process_test.go +++ b/process/peer/process_test.go @@ -25,6 +25,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -98,7 +99,7 @@ func createMockArguments() peer.ArgValidatorStatisticsProcessor { }, }, EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, BuiltInFunctionsCostHandler: &mock.BuiltInCostHandlerStub{}, TxVersionChecker: &testscommon.TxVersionCheckerStub{}, } @@ -121,7 +122,7 @@ func createMockArguments() peer.ArgValidatorStatisticsProcessor { MaxComputableRounds: 1000, MaxConsecutiveRoundsOfRatingDecrease: 2000, NodesSetup: &mock.NodesSetupStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSwitchJailWaitingFlagEnabledField: true, IsBelowSignedThresholdFlagEnabledField: true, }, @@ -1383,7 +1384,7 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksMissedRoundsGreaterTha arguments.PeerAdapter = peerAdapter arguments.NodesCoordinator = nodesCoordinatorMock arguments.MaxComputableRounds = 1 - enableEpochsHandler, _ := arguments.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := arguments.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStopDecreasingValidatorRatingWhenStuckFlagEnabledField = false arguments.MaxConsecutiveRoundsOfRatingDecrease = 4 @@ -1940,7 +1941,7 @@ func TestValidatorStatistics_RootHashWithErrShouldReturnNil(t *testing.T) { arguments := createMockArguments() peerAdapter := getAccountsMock() - peerAdapter.GetAllLeavesCalled = func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + peerAdapter.GetAllLeavesCalled = func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { return expectedErr } arguments.PeerAdapter = peerAdapter @@ -1967,7 +1968,7 @@ func TestValidatorStatistics_ResetValidatorStatisticsAtNewEpoch(t *testing.T) { marshalizedPa0, _ := arguments.Marshalizer.Marshal(pa0) peerAdapter := getAccountsMock() - peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, _ context.Context, rootHash []byte) error { + peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, _ context.Context, rootHash []byte, _ common.TrieLeafParser) error { if bytes.Equal(rootHash, hash) { go func() { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) @@ -2029,7 +2030,7 @@ func TestValidatorStatistics_Process(t *testing.T) { marshalizedPaMeta, _ := arguments.Marshalizer.Marshal(paMeta) peerAdapter := getAccountsMock() - peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { if bytes.Equal(rootHash, hash) { go func() { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) @@ -2078,7 +2079,7 @@ func TestValidatorStatistics_GetValidatorInfoForRootHash(t *testing.T) { peerAdapter := getAccountsMock() expectedErr := errors.New("expected error") - peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { if bytes.Equal(rootHash, hash) { go func() { ch.ErrChan.WriteInChanNonBlocking(expectedErr) @@ -2105,7 +2106,7 @@ func TestValidatorStatistics_GetValidatorInfoForRootHash(t *testing.T) { marshalizedPaMeta, _ := arguments.Marshalizer.Marshal(paMeta) peerAdapter := getAccountsMock() - peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { if bytes.Equal(rootHash, hash) { go func() { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) @@ -2297,7 +2298,7 @@ func TestValidatorStatistics_ProcessValidatorInfosEndOfEpochV2ComputesEligibleLe arguments.Rater = rater updateArgumentsWithNeeded(arguments) - enableEpochsHandler, _ := arguments.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := arguments.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) validatorStatistics, _ := peer.NewValidatorStatisticsProcessor(arguments) enableEpochsHandler.IsStakingV2FlagEnabledForActivationEpochCompletedField = true @@ -2553,7 +2554,7 @@ func updateArgumentsWithNeeded(arguments peer.ArgValidatorStatisticsProcessor) { marshalizedPaMeta, _ := arguments.Marshalizer.Marshal(paMeta) peerAdapter := getAccountsMock() - peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { + peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.TrieLeafParser) error { go func() { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytesMeta, marshalizedPaMeta) diff --git a/process/receipts/receiptsRepository_test.go b/process/receipts/receiptsRepository_test.go index 593bf80dac9..7d5cdd37030 100644 --- a/process/receipts/receiptsRepository_test.go +++ b/process/receipts/receiptsRepository_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" testsCommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/require" ) @@ -35,7 +36,7 @@ func TestNewReceiptsRepository(t *testing.T) { t.Run("NilHasher", func(t *testing.T) { arguments := ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: nil, Store: genericMocks.NewChainStorerMock(0), } @@ -48,7 +49,7 @@ func TestNewReceiptsRepository(t *testing.T) { t.Run("NilStorer", func(t *testing.T) { arguments := ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, Store: nil, } @@ -62,7 +63,7 @@ func TestNewReceiptsRepository(t *testing.T) { t.Run("storer not found", func(t *testing.T) { expectedErr := errors.New("expected error") arguments := ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, Store: &testsCommonStorage.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { @@ -79,7 +80,7 @@ func TestNewReceiptsRepository(t *testing.T) { t.Run("no error", func(t *testing.T) { arguments := ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, Store: genericMocks.NewChainStorerMock(0), } @@ -245,7 +246,7 @@ func TestReceiptsRepository_NoPanicOnSaveOrLoadWhenBadStorage(t *testing.T) { } repository, _ := NewReceiptsRepository(ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, Store: store, }) @@ -269,7 +270,7 @@ func TestReceiptsRepository_NoPanicOnSaveOrLoadWhenBadStorage(t *testing.T) { func TestReceiptsRepository_DecideStorageKey(t *testing.T) { repository, _ := NewReceiptsRepository(ArgsNewReceiptsRepository{ - Marshaller: testscommon.MarshalizerMock{}, + Marshaller: marshallerMock.MarshalizerMock{}, Hasher: &testscommon.HasherStub{}, Store: genericMocks.NewChainStorerMock(0), }) diff --git a/process/rewardTransaction/process_test.go b/process/rewardTransaction/process_test.go index 97112e792b3..ff64d85f500 100644 --- a/process/rewardTransaction/process_test.go +++ b/process/rewardTransaction/process_test.go @@ -12,6 +12,9 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/rewardTransaction" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/trie" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -188,7 +191,12 @@ func TestRewardTxProcessor_ProcessRewardTransactionShouldWork(t *testing.T) { accountsDb := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + return state.NewUserAccount(address, argsAccCreation) }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { saveAccountWasCalled = true @@ -220,7 +228,7 @@ func TestRewardTxProcessor_ProcessRewardTransactionMissingTrieNode(t *testing.T) missingNodeErr := fmt.Errorf(core.GetNodeFromDBErrorString) accountsDb := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - acc, _ := state.NewUserAccount(address) + acc := stateMock.NewAccountWrapMock(address) acc.SetDataTrie(&trie.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { return nil, 0, missingNodeErr @@ -255,7 +263,12 @@ func TestRewardTxProcessor_ProcessRewardTransactionToASmartContractShouldWork(t saveAccountWasCalled := false address := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6} - userAccount, _ := state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + userAccount, _ := state.NewUserAccount(address, argsAccCreation) accountsDb := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { return userAccount, nil diff --git a/process/scToProtocol/stakingToPeer_test.go b/process/scToProtocol/stakingToPeer_test.go index b32ffeddcf0..20a7cb3b26a 100644 --- a/process/scToProtocol/stakingToPeer_test.go +++ b/process/scToProtocol/stakingToPeer_test.go @@ -19,7 +19,9 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" @@ -37,7 +39,7 @@ func createMockArgumentsNewStakingToPeer() ArgStakingToPeer { ArgParser: &mock.ArgumentParserMock{}, CurrTxs: &mock.TxForCurrentBlockStub{}, RatingsData: &mock.RatingsInfoMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsStakeFlagEnabledField: true, IsValidatorToDelegationFlagEnabledField: true, }, @@ -57,6 +59,16 @@ func createBlockBody() *block.Body { } } +func createStakingScAccount() state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + userAcc, _ := state.NewUserAccount(vm.StakingSCAddress, argsAccCreation) + return userAcc +} + func TestNewStakingToPeerNilAddrConverterShouldErr(t *testing.T) { t.Parallel() @@ -256,7 +268,7 @@ func TestStakingToPeer_UpdateProtocolRemoveAccountShouldReturnNil(t *testing.T) } arguments := createMockArgumentsNewStakingToPeer() - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -310,7 +322,7 @@ func TestStakingToPeer_UpdateProtocolCannotSetRewardAddressShouldErr(t *testing. } marshalizer := &mock.MarshalizerMock{} - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -366,7 +378,7 @@ func TestStakingToPeer_UpdateProtocolEmptyDataShouldNotAddToTrie(t *testing.T) { return fmt.Errorf("error") } - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -433,7 +445,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveAccountShouldErr(t *testing.T) { } marshalizer := &mock.MarshalizerMock{} - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -496,7 +508,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveAccountNonceShouldErr(t *testing. } marshalizer := &mock.MarshalizerMock{} - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -562,7 +574,7 @@ func TestStakingToPeer_UpdateProtocol(t *testing.T) { arguments.CurrTxs = currTx arguments.PeerState = peerState arguments.Marshalizer = marshalizer - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil @@ -621,7 +633,7 @@ func TestStakingToPeer_UpdateProtocolCannotSaveUnStakedNonceShouldErr(t *testing } marshalizer := &mock.MarshalizerMock{} - userAcc, _ := state.NewUserAccount(vm.StakingSCAddress) + userAcc := createStakingScAccount() baseState := &stateMock.AccountsStub{} baseState.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { return userAcc, nil diff --git a/process/smartContract/builtInFunctions/factory_test.go b/process/smartContract/builtInFunctions/factory_test.go index 04c180235b5..abf71000038 100644 --- a/process/smartContract/builtInFunctions/factory_test.go +++ b/process/smartContract/builtInFunctions/factory_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -33,7 +34,7 @@ func createMockArguments() ArgsCreateBuiltInFunctionContainer { Accounts: &stateMock.AccountsStub{}, ShardCoordinator: mock.NewMultiShardsCoordinatorMock(1), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, AutomaticCrawlerAddresses: [][]byte{ bytes.Repeat([]byte{1}, 32), }, @@ -89,6 +90,8 @@ func fillGasMapBuiltInCosts(value uint64) map[string]uint64 { gasMap["ESDTNFTMultiTransfer"] = value gasMap["SetGuardian"] = value gasMap["GuardAccount"] = value + gasMap["TrieLoadPerNode"] = value + gasMap["TrieStorePerNode"] = value return gasMap } @@ -165,7 +168,7 @@ func TestCreateBuiltInFunctionContainer(t *testing.T) { args := createMockArguments() builtInFuncFactory, err := CreateBuiltInFunctionsFactory(args) assert.Nil(t, err) - assert.Equal(t, 35, len(builtInFuncFactory.BuiltInFunctionContainer().Keys())) + assert.Equal(t, 36, len(builtInFuncFactory.BuiltInFunctionContainer().Keys())) err = builtInFuncFactory.SetPayableHandler(&testscommon.BlockChainHookStub{}) assert.Nil(t, err) diff --git a/process/smartContract/hooks/blockChainHook_test.go b/process/smartContract/hooks/blockChainHook_test.go index 75d2b9e37c3..b01c2399d3e 100644 --- a/process/smartContract/hooks/blockChainHook_test.go +++ b/process/smartContract/hooks/blockChainHook_test.go @@ -26,7 +26,10 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/trie" @@ -58,7 +61,7 @@ func createMockBlockChainHookArgs() hooks.ArgBlockChainHook { DataPool: datapool, CompiledSCPool: datapool.SmartContracts(), EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, NilCompiledSCStore: true, EnableEpochs: config.EnableEpochs{ DoNotReturnOldBlockInBlockchainHookEnableEpoch: math.MaxUint32, @@ -80,6 +83,16 @@ func createContractCallInput(function string, sender, receiver []byte) *vmcommon } } +func createAccount(address []byte) state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + account, _ := state.NewUserAccount(address, argsAccCreation) + return account +} + func TestNewBlockChainHookImpl(t *testing.T) { t.Parallel() @@ -270,10 +283,10 @@ func TestBlockChainHookImpl_GetCode(t *testing.T) { } bh, _ := hooks.NewBlockChainHookImpl(args) - account, _ := state.NewUserAccount([]byte("address")) + account := createAccount([]byte("address")) account.SetCodeHash(expectedCodeHash) - code := bh.GetCode(account) + code := bh.GetCode(account.(vmcommon.UserAccountHandler)) require.Equal(t, expectedCode, code) }) } @@ -330,7 +343,7 @@ func TestBlockChainHookImpl_GetUserAccountWrongTypeShouldErr(t *testing.T) { func TestBlockChainHookImpl_GetUserAccount(t *testing.T) { t.Parallel() - expectedAccount, _ := state.NewUserAccount([]byte("1234")) + expectedAccount := createAccount([]byte("1234")) args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { @@ -338,7 +351,7 @@ func TestBlockChainHookImpl_GetUserAccount(t *testing.T) { }, } bh, _ := hooks.NewBlockChainHookImpl(args) - acc, err := bh.GetUserAccount(expectedAccount.Address) + acc, err := bh.GetUserAccount(expectedAccount.AddressBytes()) assert.Nil(t, err) assert.Equal(t, expectedAccount, acc) @@ -423,7 +436,7 @@ func TestBlockChainHookImpl_GetStorageData(t *testing.T) { expectedErr := errors.New("expected error") args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.Counter = &testscommon.BlockChainHookCounterStub{ @@ -474,7 +487,7 @@ func TestBlockChainHookImpl_GetStorageData(t *testing.T) { counterProcessedCalled := false args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.Counter = &testscommon.BlockChainHookCounterStub{ @@ -506,7 +519,7 @@ func TestBlockChainHookImpl_GetStorageData(t *testing.T) { counterProcessedCalled := false args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.ShardCoordinator = &testscommon.ShardsCoordinatorMock{ @@ -627,7 +640,7 @@ func TestBlockChainHookImpl_NewAddressLengthNoGood(t *testing.T) { acnts := &stateMock.AccountsStub{} acnts.GetExistingAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createAccount(address).(vmcommon.AccountHandler), nil } args := createMockBlockChainHookArgs() args.Accounts = acnts @@ -651,7 +664,7 @@ func TestBlockChainHookImpl_NewAddressVMTypeTooLong(t *testing.T) { acnts := &stateMock.AccountsStub{} acnts.GetExistingAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createAccount(address).(vmcommon.AccountHandler), nil } args := createMockBlockChainHookArgs() args.Accounts = acnts @@ -671,7 +684,7 @@ func TestBlockChainHookImpl_NewAddress(t *testing.T) { acnts := &stateMock.AccountsStub{} acnts.GetExistingAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createAccount(address).(vmcommon.AccountHandler), nil } args := createMockBlockChainHookArgs() args.Accounts = acnts @@ -830,7 +843,7 @@ func TestBlockChainHookImpl_GetBlockhashFromStorerInSameEpochWithFlagEnabled(t * t.Parallel() args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsDoNotReturnOldBlockInBlockchainHookFlagEnabledField: true, } nonce := uint64(10) @@ -988,7 +1001,7 @@ func TestBlockChainHookImpl_GettersFromBlockchainCurrentHeader(t *testing.T) { } args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsDoNotReturnOldBlockInBlockchainHookFlagEnabledField: true, } args.BlockChain = &testscommon.ChainHandlerStub{ @@ -1140,7 +1153,7 @@ func TestBlockChainHookImpl_IsPayablePayableBySC(t *testing.T) { return acc, nil }, } - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPayableBySCFlagEnabledField: true, } @@ -1747,7 +1760,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { t.Parallel() args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.BuiltInFunctions = builtInFunctionsContainer @@ -1782,7 +1795,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { counterProcessedCalled := false args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.BuiltInFunctions = builtInFunctionsContainer @@ -1818,7 +1831,7 @@ func TestBlockChainHookImpl_ProcessBuiltInFunction(t *testing.T) { counterProcessedCalled := false args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMaxBlockchainHookCountersFlagEnabledField: true, } args.BuiltInFunctions = builtInFunctionsContainer @@ -1910,14 +1923,14 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { }, } errMarshaller := errors.New("error marshaller") - args.Marshalizer = &testscommon.MarshalizerStub{ + args.Marshalizer = &marshallerMock.MarshalizerStub{ UnmarshalCalled: func(obj interface{}, buff []byte) error { require.Equal(t, emptyESDTData, obj) require.Equal(t, invalidUnmarshalledData, buff) return errMarshaller }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub @@ -1956,7 +1969,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return addressHandler, nil }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub @@ -1984,7 +1997,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return addressHandler, nil }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub @@ -2011,7 +2024,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return addressHandler, nil }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub @@ -2036,7 +2049,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return addressHandler, nil }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub @@ -2066,7 +2079,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return nil, false, expectedErr }, } - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } @@ -2095,7 +2108,7 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { return ©Token, false, nil }, } - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsOptimizeNFTStoreFlagEnabledField: true, } @@ -2136,7 +2149,7 @@ func TestBlockChainHookImpl_ApplyFiltersOnCodeMetadata(t *testing.T) { t.Parallel() args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPayableBySCFlagEnabledField: true, } bh, _ := hooks.NewBlockChainHookImpl(args) @@ -2208,7 +2221,7 @@ func TestBlockChainHookImpl_FilterCodeMetadataForUpgrade(t *testing.T) { t.Parallel() args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPayableBySCFlagEnabledField: true, } bh, _ := hooks.NewBlockChainHookImpl(args) @@ -2222,7 +2235,7 @@ func TestBlockChainHookImpl_FilterCodeMetadataForUpgrade(t *testing.T) { t.Parallel() args := createMockBlockChainHookArgs() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPayableBySCFlagEnabledField: true, } bh, _ := hooks.NewBlockChainHookImpl(args) diff --git a/process/smartContract/process_test.go b/process/smartContract/process_test.go index f13b8ce5d16..f0a29327c85 100644 --- a/process/smartContract/process_test.go +++ b/process/smartContract/process_test.go @@ -27,8 +27,10 @@ import ( "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" @@ -48,14 +50,24 @@ func createMockPubkeyConverter() *testscommon.PubkeyConverterMock { return testscommon.NewPubkeyConverterMock(32) } +func createAccount(address []byte) state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount(address, argsAccCreation) + return acc +} + func createAccounts(tx data.TransactionHandler) (state.UserAccountHandler, state.UserAccountHandler) { - acntSrc, _ := state.NewUserAccount(tx.GetSndAddr()) - acntSrc.Balance = acntSrc.Balance.Add(acntSrc.Balance, tx.GetValue()) + acntSrc := createAccount(tx.GetSndAddr()) + _ = acntSrc.AddToBalance(tx.GetValue()) totalFee := big.NewInt(0) totalFee = totalFee.Mul(big.NewInt(int64(tx.GetGasLimit())), big.NewInt(int64(tx.GetGasPrice()))) - acntSrc.Balance.Set(acntSrc.Balance.Add(acntSrc.Balance, totalFee)) + _ = acntSrc.AddToBalance(totalFee) - acntDst, _ := state.NewUserAccount(tx.GetRcvAddr()) + acntDst := createAccount(tx.GetRcvAddr()) return acntSrc, acntDst } @@ -102,7 +114,7 @@ func createMockSmartContractProcessorArguments() ArgsNewSmartContractProcessor { SetGasRefundedCalled: func(gasRefunded uint64, hash []byte) {}, }, GasSchedule: testscommon.NewGasScheduleNotifierMock(gasSchedule), - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, }, WasmVMChangeLocker: &sync.RWMutex{}, @@ -529,7 +541,7 @@ func TestScProcessor_DeploySmartContractDisabled(t *testing.T) { }} arguments.VmContainer = vmContainer arguments.ArgsParser = argParser - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsBuiltInFunctionsFlagEnabledField: true, } @@ -662,7 +674,7 @@ func TestScProcessor_ExecuteBuiltInFunctionSCResultCallSelfShard(t *testing.T) { arguments.AccountsDB = accountState arguments.VmContainer = vmContainer arguments.ArgsParser = argParser - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} arguments.EnableEpochsHandler = enableEpochsHandlerStub funcName := "builtIn" sc, err := NewSmartContractProcessor(arguments) @@ -722,7 +734,7 @@ func TestScProcessor_ExecuteBuiltInFunctionSCResultCallSelfShardCannotSaveLog(t arguments.AccountsDB = accountState arguments.VmContainer = vmContainer arguments.ArgsParser = argParser - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} arguments.EnableEpochsHandler = enableEpochsHandlerStub funcName := "builtIn" sc, err := NewSmartContractProcessor(arguments) @@ -773,7 +785,7 @@ func TestScProcessor_ExecuteBuiltInFunction(t *testing.T) { arguments.AccountsDB = accountState arguments.VmContainer = vmContainer arguments.ArgsParser = argParser - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} arguments.EnableEpochsHandler = enableEpochsHandlerStub funcName := "builtIn" sc, err := NewSmartContractProcessor(arguments) @@ -816,7 +828,7 @@ func TestScProcessor_ExecuteBuiltInFunctionSCRTooBig(t *testing.T) { arguments.AccountsDB = accountState arguments.VmContainer = vmContainer arguments.ArgsParser = argParser - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsBuiltInFunctionsFlagEnabledField: true, } arguments.EnableEpochsHandler = enableEpochsHandlerStub @@ -2146,7 +2158,7 @@ func TestScProcessor_GetAccountFromAddr(t *testing.T) { getCalled := 0 accountsDB.LoadAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { getCalled++ - acc, _ := state.NewUserAccount(address) + acc := createAccount(address) return acc, nil } @@ -2250,8 +2262,7 @@ func TestScProcessor_DeleteAccountsInShard(t *testing.T) { accountsDB := &stateMock.AccountsStub{} removeCalled := 0 accountsDB.LoadAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { - acc, _ := state.NewUserAccount(address) - return acc, nil + return createAccount(address), nil } accountsDB.RemoveAccountCalled = func(address []byte) error { removeCalled++ @@ -2328,7 +2339,7 @@ func TestScProcessor_ProcessSCPaymentNotEnoughBalance(t *testing.T) { tx.GasPrice = 10 tx.GasLimit = 15 - acntSrc, _ := state.NewUserAccount(tx.SndAddr) + acntSrc := createAccount(tx.SndAddr) _ = acntSrc.AddToBalance(big.NewInt(45)) currBalance := acntSrc.GetBalance().Uint64() @@ -2382,7 +2393,7 @@ func TestScProcessor_ProcessSCPaymentWithNewFlags(t *testing.T) { return core.SafeMul(tx.GetGasPrice(), gasToUse) }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPenalizedTooMuchGasFlagEnabledField: true, } arguments.EnableEpochsHandler = enableEpochsHandlerStub @@ -2495,7 +2506,7 @@ func TestScProcessor_RefundGasToSender(t *testing.T) { arguments.EconomicsFee = &economicsmocks.EconomicsHandlerStub{MinGasPriceCalled: func() uint64 { return minGasPrice }} - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{} + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{} sc, err := NewSmartContractProcessor(arguments) require.NotNil(t, sc) require.Nil(t, err) @@ -2535,7 +2546,7 @@ func TestScProcessor_DoNotRefundGasToSenderForAsyncCall(t *testing.T) { arguments.EconomicsFee = &economicsmocks.EconomicsHandlerStub{MinGasPriceCalled: func() uint64 { return minGasPrice }} - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{} + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{} sc, err := NewSmartContractProcessor(arguments) require.NotNil(t, sc) require.Nil(t, err) @@ -2618,7 +2629,7 @@ func TestScProcessor_processSCOutputAccounts(t *testing.T) { outputAccounts = append(outputAccounts, outacc1) testAddr := outaddress - testAcc, _ := state.NewUserAccount(testAddr) + testAcc := createAccount(testAddr) accountsDB.LoadAccountCalled = func(address []byte) (handler vmcommon.AccountHandler, e error) { if bytes.Equal(address, testAddr) { @@ -2645,11 +2656,11 @@ func TestScProcessor_processSCOutputAccounts(t *testing.T) { outacc1.BalanceDelta = big.NewInt(int64(10)) tx.Value = big.NewInt(int64(10)) - currentBalance := testAcc.Balance.Uint64() + currentBalance := testAcc.GetBalance().Uint64() vmOutBalance := outacc1.BalanceDelta.Uint64() _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmData.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) - require.Equal(t, currentBalance+vmOutBalance, testAcc.Balance.Uint64()) + require.Equal(t, currentBalance+vmOutBalance, testAcc.GetBalance().Uint64()) } func TestScProcessor_processSCOutputAccountsNotInShard(t *testing.T) { @@ -2691,7 +2702,7 @@ func TestScProcessor_processSCOutputAccountsNotInShard(t *testing.T) { func TestScProcessor_CreateCrossShardTransactions(t *testing.T) { t.Parallel() - testAccounts, _ := state.NewUserAccount([]byte("address")) + testAccounts := createAccount([]byte("address")) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, err error) { return testAccounts, nil @@ -2738,7 +2749,7 @@ func TestScProcessor_CreateCrossShardTransactions(t *testing.T) { func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { t.Parallel() - testAccounts, _ := state.NewUserAccount([]byte("address")) + testAccounts := createAccount([]byte("address")) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, err error) { return testAccounts, nil @@ -2749,7 +2760,7 @@ func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { } shardCoordinator := mock.NewMultiShardsCoordinatorMock(5) arguments := createMockSmartContractProcessorArguments() - enableEpochsHandler := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFixAsyncCallBackArgsListFlagEnabledField: false, } arguments.EnableEpochsHandler = enableEpochsHandler @@ -2830,7 +2841,7 @@ func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { func TestScProcessor_CreateIntraShardTransactionsWithAsyncCalls(t *testing.T) { t.Parallel() - testAccounts, _ := state.NewUserAccount([]byte("address")) + testAccounts := createAccount([]byte("address")) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, err error) { return testAccounts, nil @@ -2843,7 +2854,7 @@ func TestScProcessor_CreateIntraShardTransactionsWithAsyncCalls(t *testing.T) { arguments := createMockSmartContractProcessorArguments() arguments.AccountsDB = accountsDB arguments.ShardCoordinator = shardCoordinator - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMultiESDTTransferFixOnCallBackFlagEnabledField: true, } sc, err := NewSmartContractProcessor(arguments) @@ -2971,13 +2982,15 @@ func TestScProcessor_ProcessSmartContractResultBadAccType(t *testing.T) { func TestScProcessor_ProcessSmartContractResultNotPayable(t *testing.T) { t.Parallel() - userAcc, _ := state.NewUserAccount([]byte("recv address")) + userAcc := createAccount([]byte("recv address")) accountsDB := &stateMock.AccountsStub{ - LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - if bytes.Equal(address, userAcc.Address) { + LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { + if bytes.Equal(address, userAcc.AddressBytes()) { return userAcc, nil } - return state.NewEmptyUserAccount(), nil + return &stateMock.AccountWrapMock{ + Balance: big.NewInt(0), + }, nil }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { return nil @@ -3001,7 +3014,7 @@ func TestScProcessor_ProcessSmartContractResultNotPayable(t *testing.T) { require.Nil(t, err) scr := smartContractResult.SmartContractResult{ - RcvAddr: userAcc.Address, + RcvAddr: userAcc.AddressBytes(), SndAddr: []byte("snd addr"), Value: big.NewInt(0), } @@ -3026,8 +3039,8 @@ func TestScProcessor_ProcessSmartContractResultOutputBalanceNil(t *testing.T) { t.Parallel() accountsDB := &stateMock.AccountsStub{ - LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return state.NewUserAccount(address) + LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { + return createAccount(address), nil }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { return nil @@ -3056,7 +3069,7 @@ func TestScProcessor_ProcessSmartContractResultWithCode(t *testing.T) { putCodeCalled := 0 accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return state.NewUserAccount(address) + return createAccount(address), nil }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { putCodeCalled++ @@ -3091,7 +3104,7 @@ func TestScProcessor_ProcessSmartContractResultWithData(t *testing.T) { saveAccountCalled := 0 accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return state.NewUserAccount(address) + return createAccount(address), nil }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { saveAccountCalled++ @@ -3133,7 +3146,7 @@ func TestScProcessor_ProcessSmartContractResultDeploySCShouldError(t *testing.T) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { - return state.NewUserAccount(address) + return createAccount(address), nil }, SaveAccountCalled: func(accountHandler vmcommon.AccountHandler) error { return nil @@ -3169,7 +3182,7 @@ func TestScProcessor_ProcessSmartContractResultExecuteSC(t *testing.T) { t.Parallel() scAddress := []byte("000000000001234567890123456789012") - dstScAddress, _ := state.NewUserAccount(scAddress) + dstScAddress := createAccount(scAddress) dstScAddress.SetCode([]byte("code")) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { @@ -3231,7 +3244,7 @@ func TestScProcessor_ProcessSmartContractResultExecuteSCIfMetaAndBuiltIn(t *test t.Parallel() scAddress := []byte("000000000001234567890123456789012") - dstScAddress, _ := state.NewUserAccount(scAddress) + dstScAddress := createAccount(scAddress) dstScAddress.SetCode([]byte("code")) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { @@ -3275,7 +3288,7 @@ func TestScProcessor_ProcessSmartContractResultExecuteSCIfMetaAndBuiltIn(t *test return process.BuiltInFunctionCall, process.BuiltInFunctionCall }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, } arguments.EnableEpochsHandler = enableEpochsHandlerStub @@ -3306,15 +3319,15 @@ func TestScProcessor_ProcessRelayedSCRValueBackToRelayer(t *testing.T) { t.Parallel() scAddress := []byte("000000000001234567890123456789012") - dstScAddress, _ := state.NewUserAccount(scAddress) + dstScAddress := createAccount(scAddress) dstScAddress.SetCode([]byte("code")) baseValue := big.NewInt(100) userAddress := []byte("111111111111234567890123456789012") - userAcc, _ := state.NewUserAccount(userAddress) + userAcc := createAccount(userAddress) _ = userAcc.AddToBalance(baseValue) relayedAddress := []byte("211111111111234567890123456789012") - relayedAcc, _ := state.NewUserAccount(relayedAddress) + relayedAcc := createAccount(relayedAddress) accountsDB := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (handler vmcommon.AccountHandler, e error) { @@ -3399,8 +3412,7 @@ func TestScProcessor_checkUpgradePermission(t *testing.T) { require.Equal(t, process.ErrUpgradeNotAllowed, err) // Create a contract, owned by Alice - contract, err := state.NewUserAccount([]byte("contract")) - require.Nil(t, err) + contract := createAccount([]byte("contract")) contract.SetOwnerAddress([]byte("alice")) // Not yet upgradeable contract.SetCodeMetadata([]byte{0, 0}) @@ -3433,7 +3445,7 @@ func TestScProcessor_penalizeUserIfNeededShouldWork(t *testing.T) { t.Parallel() arguments := createMockSmartContractProcessorArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPenalizedTooMuchGasFlagEnabledField: true, } sc, _ := NewSmartContractProcessor(arguments) @@ -3514,7 +3526,7 @@ func TestScProcessor_penalizeUserIfNeededShouldWorkOnFlagActivation(t *testing.T func TestSCProcessor_createSCRWhenError(t *testing.T) { arguments := createMockSmartContractProcessorArguments() - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, IsPenalizedTooMuchGasFlagEnabledField: true, IsRepairCallbackFlagEnabledField: true, @@ -3577,7 +3589,7 @@ func TestGasLockedInSmartContractProcessor(t *testing.T) { return shardCoordinator.SelfId() + 1 } arguments.ShardCoordinator = shardCoordinator - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMultiESDTTransferFixOnCallBackFlagEnabledField: true, } sc, _ := NewSmartContractProcessor(arguments) @@ -3695,8 +3707,7 @@ func TestSmartContractProcessor_computeTotalConsumedFeeAndDevRwdWithDifferentSCC t.Parallel() scAccountAddress := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x1e, 0x2e, 0x61, 0x1a, 0x9c, 0xe1, 0xe0, 0xc8, 0xe3, 0x28, 0x3c, 0xcc, 0x7c, 0x1b, 0x0f, 0x46, 0x61, 0x91, 0x70, 0x79, 0xa7, 0x5c} - acc, err := state.NewUserAccount(scAccountAddress) - require.Nil(t, err) + acc := createAccount(scAccountAddress) require.NotNil(t, acc) arguments := createMockSmartContractProcessorArguments() @@ -3722,7 +3733,7 @@ func TestSmartContractProcessor_computeTotalConsumedFeeAndDevRwdWithDifferentSCC return acc, nil }, } - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, IsStakingV2FlagEnabledForActivationEpochCompletedField: true, } @@ -3785,8 +3796,7 @@ func TestSmartContractProcessor_finishSCExecutionV2(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - acc, err := state.NewUserAccount(scAccountAddress) - require.Nil(t, err) + acc := createAccount(scAccountAddress) require.NotNil(t, acc) arguments := createMockSmartContractProcessorArguments() @@ -3797,6 +3807,7 @@ func TestSmartContractProcessor_finishSCExecutionV2(t *testing.T) { // use a real fee handler args := createRealEconomicsDataArgs() + var err error arguments.EconomicsFee, err = economics.NewEconomicsData(*args) require.Nil(t, err) @@ -3812,7 +3823,7 @@ func TestSmartContractProcessor_finishSCExecutionV2(t *testing.T) { return acc, nil }, } - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, IsStakingV2FlagEnabledForActivationEpochCompletedField: true, } @@ -3827,7 +3838,7 @@ func TestSmartContractProcessor_finishSCExecutionV2(t *testing.T) { require.Nil(t, err) require.Equal(t, retcode, vmcommon.Ok) require.Nil(t, err) - require.Equal(t, expectedDevFees, acc.DeveloperReward) + require.Equal(t, expectedDevFees, acc.GetDeveloperReward()) require.Equal(t, expectedTotalFee, sc.txFeeHandler.GetAccumulatedFees()) require.Equal(t, expectedDevFees, sc.txFeeHandler.GetDeveloperFees()) }) @@ -3995,7 +4006,7 @@ func TestProcessIfErrorCheckBackwardsCompatibilityProcessTransactionFeeCalledSho }, } - arguments.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, IsCleanUpInformativeSCRsFlagEnabledField: true, IsOptimizeGasUsedInCrossMiniBlocksFlagEnabledField: true, @@ -4018,7 +4029,7 @@ func TestProcessSCRSizeTooBig(t *testing.T) { t.Parallel() arguments := createMockSmartContractProcessorArguments() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} arguments.EnableEpochsHandler = enableEpochsHandlerStub sc, _ := NewSmartContractProcessor(arguments) @@ -4072,7 +4083,7 @@ func TestCleanInformativeOnlySCRs(t *testing.T) { builtInFuncs := builtInFunctions.NewBuiltInFunctionContainer() arguments.BuiltInFunctions = builtInFuncs arguments.ArgsParser = NewArgumentParser() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} arguments.EnableEpochsHandler = enableEpochsHandlerStub sc, _ := NewSmartContractProcessor(arguments) @@ -4170,7 +4181,7 @@ func TestProcess_createCompletedTxEvent(t *testing.T) { scrWithRefund := &smartContractResult.SmartContractResult{Value: big.NewInt(10), PrevTxHash: scrHash, Data: []byte("@6f6b@aaffaa")} completedLogSaved = false - acntDst, _ := state.NewUserAccount(userAddress) + acntDst := createAccount(userAddress) err := sc.processSimpleSCR(scrWithRefund, []byte("scrHash"), acntDst) assert.Nil(t, err) assert.True(t, completedLogSaved) @@ -4218,7 +4229,7 @@ func createRealEconomicsDataArgs() *economics.ArgsNewEconomicsData { }, }, EpochNotifier: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsGasPriceModifierFlagEnabledField: true, }, BuiltInFunctionsCostHandler: &mock.BuiltInCostHandlerStub{}, @@ -4328,7 +4339,7 @@ func TestScProcessor_TooMuchGasProvidedMessage(t *testing.T) { t.Parallel() arguments := createMockSmartContractProcessorArguments() - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsSCDeployFlagEnabledField: true, IsPenalizedTooMuchGasFlagEnabledField: true, } diff --git a/process/sync/storageBootstrap/shardStorageBootstrapper_test.go b/process/sync/storageBootstrap/shardStorageBootstrapper_test.go index f99923dc214..f518b21b788 100644 --- a/process/sync/storageBootstrap/shardStorageBootstrapper_test.go +++ b/process/sync/storageBootstrap/shardStorageBootstrapper_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageMock "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -34,7 +35,7 @@ func TestShardStorageBootstrapper_LoadFromStorageShouldWork(t *testing.T) { wasCalledEpochNotifier := false savedLastRound := int64(0) - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} startRound := 4000 hdr := &block.Header{ Nonce: 3999, @@ -105,7 +106,7 @@ func TestShardStorageBootstrapper_LoadFromStorageShouldWork(t *testing.T) { return nil }, }, - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, Store: &storageMock.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { return blockStorerMock, nil diff --git a/process/sync/trieIterators/tokensSuppliesProcessor.go b/process/sync/trieIterators/tokensSuppliesProcessor.go index 632115eb214..2906b267267 100644 --- a/process/sync/trieIterators/tokensSuppliesProcessor.go +++ b/process/sync/trieIterators/tokensSuppliesProcessor.go @@ -16,7 +16,6 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dblookupext/esdtSupply" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" ) type tokensSuppliesProcessor struct { @@ -64,30 +63,27 @@ func (t *tokensSuppliesProcessor) HandleTrieAccountIteration(userAccount state.U return nil } - dataTrie := &common.TrieIteratorChannels{ + dataTrieChan := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - errDataTrieGet := userAccount.DataTrie().GetAllLeavesOnChannel(dataTrie, context.Background(), rh, keyBuilder.NewKeyBuilder()) + errDataTrieGet := userAccount.GetAllLeaves(dataTrieChan, context.Background()) if errDataTrieGet != nil { return fmt.Errorf("%w while getting all leaves for root hash %s", errDataTrieGet, hex.EncodeToString(rh)) } log.Trace("extractTokensSupplies - parsing account", "address", userAccount.AddressBytes()) esdtPrefix := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier) - for userLeaf := range dataTrie.LeavesChan { + for userLeaf := range dataTrieChan.LeavesChan { if !bytes.HasPrefix(userLeaf.Key(), esdtPrefix) { continue } tokenKey := userLeaf.Key() lenESDTPrefix := len(esdtPrefix) - suffix := append(userLeaf.Key(), userAccount.AddressBytes()...) - value, errVal := userLeaf.ValueWithoutSuffix(suffix) - if errVal != nil { - return fmt.Errorf("%w while parsing the token with key %s", errVal, hex.EncodeToString(tokenKey)) - } + value := userLeaf.Value() + var esToken esdt.ESDigitalToken err := t.marshaller.Unmarshal(&esToken, value) if err != nil { @@ -99,7 +95,7 @@ func (t *tokensSuppliesProcessor) HandleTrieAccountIteration(userAccount state.U t.addToBalance(tokenID, nonce, esToken.Value) } - err := dataTrie.ErrChan.ReadFromChanNonBlocking() + err := dataTrieChan.ErrChan.ReadFromChanNonBlocking() if err != nil { return fmt.Errorf("%w while parsing errors from the trie iteration", err) } diff --git a/process/sync/trieIterators/tokensSuppliesProcessor_test.go b/process/sync/trieIterators/tokensSuppliesProcessor_test.go index 21eaf09f919..8effb5571a9 100644 --- a/process/sync/trieIterators/tokensSuppliesProcessor_test.go +++ b/process/sync/trieIterators/tokensSuppliesProcessor_test.go @@ -14,8 +14,10 @@ import ( coreEsdt "github.com/multiversx/mx-chain-go/dblookupext/esdtSupply" "github.com/multiversx/mx-chain-go/state" chainStorage "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/trie" @@ -26,7 +28,7 @@ import ( func getTokensSuppliesProcessorArgs() ArgsTokensSuppliesProcessor { return ArgsTokensSuppliesProcessor{ StorageService: &genericMocks.ChainStorerMock{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, } } @@ -69,6 +71,12 @@ func TestNewTokensSuppliesProcessor(t *testing.T) { func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { t.Parallel() + userAccArgs := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + t.Run("nil user account", func(t *testing.T) { t.Parallel() @@ -115,12 +123,16 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { tsp, _ := NewTokensSuppliesProcessor(args) expectedErr := errors.New("error") - userAcc, _ := state.NewUserAccount([]byte("addr")) + + userAcc, _ := state.NewUserAccount([]byte("addr"), userAccArgs) userAcc.SetRootHash([]byte("rootHash")) userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { return expectedErr }, + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, }) err := tsp.HandleTrieAccountIteration(userAcc) @@ -134,47 +146,22 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - userAcc, _ := state.NewUserAccount([]byte("addr")) + userAcc, _ := state.NewUserAccount([]byte("addr"), userAccArgs) userAcc.SetRootHash([]byte("rootHash")) userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("not a token key"), []byte("not a token value")) close(leavesChannels.LeavesChan) return nil }, - }) - - err := tsp.HandleTrieAccountIteration(userAcc) - require.NoError(t, err) - require.Empty(t, tsp.tokensSupplies) - }) - - t.Run("should return error if trie value cannot be extracted", func(t *testing.T) { - t.Parallel() - - args := getTokensSuppliesProcessorArgs() - tsp, _ := NewTokensSuppliesProcessor(args) - - userAcc, _ := state.NewUserAccount([]byte("addr")) - userAcc.SetRootHash([]byte("rootHash")) - userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { - esToken := &esdt.ESDigitalToken{ - Value: big.NewInt(37), - } - esBytes, _ := args.Marshaller.Marshal(esToken) - tknKey := []byte("ELRONDesdtTKN-00aacc") - leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(tknKey, esBytes) - - close(leavesChannels.LeavesChan) - return nil + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil }, }) err := tsp.HandleTrieAccountIteration(userAcc) - require.Error(t, err) - require.Contains(t, err.Error(), "suffix is not present or the position is incorrect") + require.NoError(t, err) require.Empty(t, tsp.tokensSupplies) }) @@ -184,10 +171,10 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - userAcc, _ := state.NewUserAccount(vmcommon.SystemAccountAddress) + userAcc, _ := state.NewUserAccount(vmcommon.SystemAccountAddress, userAccArgs) userAcc.SetRootHash([]byte("rootHash")) userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { esToken := &esdt.ESDigitalToken{ Value: big.NewInt(37), } @@ -213,10 +200,10 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - userAcc, _ := state.NewUserAccount([]byte("addr")) + userAcc, _ := state.NewUserAccount([]byte("addr"), userAccArgs) userAcc.SetRootHash([]byte("rootHash")) userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { esToken := &esdt.ESDigitalToken{ Value: big.NewInt(37), } @@ -224,7 +211,9 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { tknKey := []byte("ELRONDesdtTKN-00aacc") value := append(esBytes, tknKey...) value = append(value, []byte("addr")...) - leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(tknKey, value) + leaf, err := leafParser.ParseLeaf(tknKey, value, 0) + require.Nil(t, err) + leavesChannels.LeavesChan <- leaf sft := &esdt.ESDigitalToken{ Value: big.NewInt(1), @@ -234,11 +223,16 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { sftKey = append(sftKey, big.NewInt(37).Bytes()...) value = append(sftBytes, sftKey...) value = append(value, []byte("addr")...) - leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(sftKey, value) + leaf, err = leafParser.ParseLeaf(sftKey, value, 0) + require.Nil(t, err) + leavesChannels.LeavesChan <- leaf close(leavesChannels.LeavesChan) return nil }, + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, }) err := tsp.HandleTrieAccountIteration(userAcc) diff --git a/process/sync/trieIterators/trieAccountsIterator.go b/process/sync/trieIterators/trieAccountsIterator.go index e936d723e3e..c05992b7602 100644 --- a/process/sync/trieIterators/trieAccountsIterator.go +++ b/process/sync/trieIterators/trieAccountsIterator.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -58,7 +59,7 @@ func (t *trieAccountsIterator) Process(handlers ...TrieAccountIteratorHandler) e LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = t.accounts.GetAllLeaves(iteratorChannels, context.Background(), rootHash) + err = t.accounts.GetAllLeaves(iteratorChannels, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) if err != nil { return err } diff --git a/process/sync/trieIterators/trieAccountsIterator_test.go b/process/sync/trieIterators/trieAccountsIterator_test.go index 8eb00d7a7f3..87712d52168 100644 --- a/process/sync/trieIterators/trieAccountsIterator_test.go +++ b/process/sync/trieIterators/trieAccountsIterator_test.go @@ -8,7 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/require" @@ -16,7 +16,7 @@ import ( func getTrieAccountsIteratorArgs() ArgsTrieAccountsIterator { return ArgsTrieAccountsIterator{ - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, Accounts: &stateMock.AccountsStub{}, } } @@ -102,7 +102,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { return expectedErr }, } @@ -120,7 +120,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: []byte("rootHash"), } @@ -147,7 +147,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: []byte("rootHash"), } @@ -175,7 +175,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: nil, } @@ -202,9 +202,10 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { - userAcc := state.NewEmptyUserAccount() - userAcc.SetRootHash([]byte("root")) + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { + userAcc := stateMock.AccountWrapMock{ + RootHash: []byte("root"), + } userAccBytes, _ := args.Marshaller.Marshal(userAcc) iter.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("addr"), userAccBytes) close(iter.LeavesChan) @@ -228,7 +229,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: []byte("rootHash"), } @@ -261,7 +262,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: []byte("rootHash"), } @@ -297,7 +298,7 @@ func TestTrieAccountsIterator_Process(t *testing.T) { RootHashCalled: func() ([]byte, error) { return []byte("rootHash"), nil }, - GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(iter *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { userAcc := &stateMock.AccountWrapMock{ RootHash: []byte("rootHash"), } diff --git a/process/transaction/baseProcess_test.go b/process/transaction/baseProcess_test.go index 04f58184562..04bc644c49f 100644 --- a/process/transaction/baseProcess_test.go +++ b/process/transaction/baseProcess_test.go @@ -13,8 +13,10 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" @@ -55,9 +57,9 @@ func Test_checkGuardedAccountUnguardedTxPermission(t *testing.T) { }, }, hasher: &hashingMocks.HasherMock{}, - marshalizer: &testscommon.MarshalizerMock{}, + marshalizer: &marshallerMock.MarshalizerMock{}, scProcessor: &testscommon.SCProcessorMock{}, - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPenalizedTooMuchGasFlagEnabledField: true, }, txVersionChecker: &testscommon.TxVersionCheckerStub{}, @@ -115,9 +117,9 @@ func TestBaseTxProcessor_VerifyGuardian(t *testing.T) { }, }, hasher: &hashingMocks.HasherMock{}, - marshalizer: &testscommon.MarshalizerMock{}, + marshalizer: &marshallerMock.MarshalizerMock{}, scProcessor: &testscommon.SCProcessorMock{}, - enableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPenalizedTooMuchGasFlagEnabledField: true, }, txVersionChecker: &testscommon.TxVersionCheckerStub{}, diff --git a/process/transaction/interceptedTransaction_test.go b/process/transaction/interceptedTransaction_test.go index 74efd6716b0..b2aa2e81526 100644 --- a/process/transaction/interceptedTransaction_test.go +++ b/process/transaction/interceptedTransaction_test.go @@ -22,6 +22,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -73,7 +74,7 @@ func createFreeTxFeeHandler() *economicsmocks.EconomicsHandlerStub { } func createInterceptedTxWithTxFeeHandlerAndVersionChecker(tx *dataTransaction.Transaction, txFeeHandler process.FeeHandler, txVerChecker *testscommon.TxVersionCheckerStub) (*transaction.InterceptedTransaction, error) { - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} txBuff, err := marshaller.Marshal(tx) if err != nil { return nil, err diff --git a/process/transaction/metaProcess_test.go b/process/transaction/metaProcess_test.go index 0db0be2af50..e3eeac2a36c 100644 --- a/process/transaction/metaProcess_test.go +++ b/process/transaction/metaProcess_test.go @@ -15,8 +15,10 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -35,13 +37,23 @@ func createMockNewMetaTxArgs() txproc.ArgsNewMetaTxProcessor { ScProcessor: &testscommon.SCProcessorMock{}, TxTypeHandler: &testscommon.TxTypeHandlerMock{}, EconomicsFee: createFreeTxFeeHandler(), - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, GuardianChecker: &guardianMocks.GuardedAccountHandlerStub{}, TxVersionChecker: &testscommon.TxVersionCheckerStub{}, } return args } +func createUserAcc(address []byte) state.UserAccountHandler { + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acc, _ := state.NewUserAccount(address, argsAccCreation) + return acc +} + // ------- NewMetaTxProcessor func TestNewMetaTxProcessor_NilAccountsShouldErr(t *testing.T) { @@ -160,10 +172,8 @@ func TestMetaTxProcessor_ProcessCheckNotPassShouldErr(t *testing.T) { tx.RcvAddr = []byte("DST") tx.Value = big.NewInt(45) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -171,7 +181,7 @@ func TestMetaTxProcessor_ProcessCheckNotPassShouldErr(t *testing.T) { args.Accounts = adb txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Equal(t, process.ErrHigherNonceInTransaction, err) } @@ -186,10 +196,8 @@ func TestMetaTxProcessor_ProcessMoveBalancesShouldCallProcessIfError(t *testing. tx.RcvAddr = []byte("DST") tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -208,7 +216,7 @@ func TestMetaTxProcessor_ProcessMoveBalancesShouldCallProcessIfError(t *testing. } txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Equal(t, nil, err) assert.True(t, called) } @@ -226,13 +234,10 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldWork(t *testing.T) { tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(46) + _ = acntSrc.AddToBalance(big.NewInt(46)) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -258,7 +263,7 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldWork(t *testing.T) { } txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Nil(t, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -275,11 +280,9 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldReturnErrWhenExecutionFails tx.RcvAddr = generateRandomByteSlice(createMockPubKeyConverter().Len()) tx.Value = big.NewInt(45) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(45) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(45)) + acntDst := createUserAcc(tx.RcvAddr) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -306,7 +309,7 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldReturnErrWhenExecutionFails } txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Equal(t, process.ErrNoVM, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -333,11 +336,9 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotI return 0 } - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(45) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(45)) + acntDst := createUserAcc(tx.RcvAddr) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -365,7 +366,7 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotI BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -378,7 +379,7 @@ func TestMetaTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotI args.ShardCoordinator = shardCoordinator txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Equal(t, nil, err) assert.False(t, wasCalled) assert.True(t, calledIfError) @@ -397,13 +398,10 @@ func TestMetaTxProcessor_ProcessTransactionBuiltInCallTxShouldWork(t *testing.T) tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(46) + _ = acntSrc.AddToBalance(big.NewInt(46)) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -427,14 +425,14 @@ func TestMetaTxProcessor_ProcessTransactionBuiltInCallTxShouldWork(t *testing.T) return process.BuiltInFunctionCall, process.BuiltInFunctionCall }, } - enableEpochsHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsBuiltInFunctionOnMetaFlagEnabledField: false, IsESDTFlagEnabledField: true, } args.EnableEpochsHandler = enableEpochsHandlerStub txProc, _ := txproc.NewMetaTxProcessor(args) - _, err = txProc.ProcessTransaction(&tx) + _, err := txProc.ProcessTransaction(&tx) assert.Nil(t, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -465,7 +463,12 @@ func TestMetaTxProcessor_ProcessTransactionWithInvalidUsernameShouldNotError(t * tx.GasPrice = 1 tx.GasLimit = 1 - acntDst, err := state.NewUserAccount(tx.RcvAddr) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + acntDst, err := state.NewUserAccount(tx.RcvAddr, argsAccCreation) assert.Nil(t, err) called := false diff --git a/process/transaction/shardProcess_test.go b/process/transaction/shardProcess_test.go index 61c56203d6a..bbbd316171a 100644 --- a/process/transaction/shardProcess_test.go +++ b/process/transaction/shardProcess_test.go @@ -21,6 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" @@ -85,7 +86,7 @@ func createArgsForTxProcessor() txproc.ArgsNewTxProcessor { BadTxForwarder: &mock.IntermediateTransactionHandlerMock{}, ArgsParser: &mock.ArgumentParserMock{}, ScrForwarder: &mock.IntermediateTransactionHandlerMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsPenalizedTooMuchGasFlagEnabledField: true, }, GuardianChecker: &guardianMocks.GuardedAccountHandlerStub{}, @@ -320,8 +321,8 @@ func TestTxProcessor_GetAccountsOkValsSrcShouldWork(t *testing.T) { adr1 := []byte{65} adr2 := []byte{67} - acnt1, _ := state.NewUserAccount(adr1) - acnt2, _ := state.NewUserAccount(adr2) + acnt1 := createUserAcc(adr1) + acnt2 := createUserAcc(adr2) adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { if bytes.Equal(address, adr1) { @@ -365,8 +366,8 @@ func TestTxProcessor_GetAccountsOkValsDsthouldWork(t *testing.T) { adr1 := []byte{65} adr2 := []byte{67} - acnt1, _ := state.NewUserAccount(adr1) - acnt2, _ := state.NewUserAccount(adr2) + acnt1 := createUserAcc(adr1) + acnt2 := createUserAcc(adr2) adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { if bytes.Equal(address, adr1) { @@ -408,8 +409,8 @@ func TestTxProcessor_GetAccountsOkValsShouldWork(t *testing.T) { adr1 := []byte{65} adr2 := []byte{67} - acnt1, _ := state.NewUserAccount(adr1) - acnt2, _ := state.NewUserAccount(adr2) + acnt1 := createUserAcc(adr1) + acnt2 := createUserAcc(adr2) adb := createAccountStub(adr1, adr2, acnt1, acnt2) @@ -429,8 +430,8 @@ func TestTxProcessor_GetSameAccountShouldWork(t *testing.T) { adr1 := []byte{65} adr2 := []byte{65} - acnt1, _ := state.NewUserAccount(adr1) - acnt2, _ := state.NewUserAccount(adr2) + acnt1 := createUserAcc(adr1) + acnt2 := createUserAcc(adr2) adb := createAccountStub(adr1, adr2, acnt1, acnt2) @@ -449,14 +450,13 @@ func TestTxProcessor_CheckTxValuesHigherNonceShouldErr(t *testing.T) { t.Parallel() adr1 := []byte{65} - acnt1, err := state.NewUserAccount(adr1) - assert.Nil(t, err) + acnt1 := createUserAcc(adr1) execTx := *createTxProcessor() - acnt1.Nonce = 6 + acnt1.IncreaseNonce(6) - err = execTx.CheckTxValues(&transaction.Transaction{Nonce: 7}, acnt1, nil, false) + err := execTx.CheckTxValues(&transaction.Transaction{Nonce: 7}, acnt1, nil, false) assert.Equal(t, process.ErrHigherNonceInTransaction, err) } @@ -464,14 +464,13 @@ func TestTxProcessor_CheckTxValuesLowerNonceShouldErr(t *testing.T) { t.Parallel() adr1 := []byte{65} - acnt1, err := state.NewUserAccount(adr1) - assert.Nil(t, err) + acnt1 := createUserAcc(adr1) execTx := *createTxProcessor() - acnt1.Nonce = 6 + acnt1.IncreaseNonce(6) - err = execTx.CheckTxValues(&transaction.Transaction{Nonce: 5}, acnt1, nil, false) + err := execTx.CheckTxValues(&transaction.Transaction{Nonce: 5}, acnt1, nil, false) assert.Equal(t, process.ErrLowerNonceInTransaction, err) } @@ -479,14 +478,13 @@ func TestTxProcessor_CheckTxValuesInsufficientFundsShouldErr(t *testing.T) { t.Parallel() adr1 := []byte{65} - acnt1, err := state.NewUserAccount(adr1) - assert.Nil(t, err) + acnt1 := createUserAcc(adr1) execTx := *createTxProcessor() - acnt1.Balance = big.NewInt(67) + _ = acnt1.AddToBalance(big.NewInt(67)) - err = execTx.CheckTxValues(&transaction.Transaction{Value: big.NewInt(68)}, acnt1, nil, false) + err := execTx.CheckTxValues(&transaction.Transaction{Value: big.NewInt(68)}, acnt1, nil, false) assert.Equal(t, process.ErrInsufficientFunds, err) } @@ -494,21 +492,19 @@ func TestTxProcessor_CheckTxValuesMismatchedSenderUsernamesShouldErr(t *testing. t.Parallel() adr1 := []byte{65} - senderAcc, err := state.NewUserAccount(adr1) - - assert.Nil(t, err) + senderAcc := createUserAcc(adr1) execTx := *createTxProcessor() - senderAcc.Balance = big.NewInt(67) - senderAcc.UserName = []byte("SRC") + _ = senderAcc.AddToBalance(big.NewInt(67)) + senderAcc.SetUserName([]byte("SRC")) tx := &transaction.Transaction{ Value: big.NewInt(10), SndUserName: []byte("notCorrect"), } - err = execTx.CheckTxValues(tx, senderAcc, nil, false) + err := execTx.CheckTxValues(tx, senderAcc, nil, false) assert.Equal(t, process.ErrUserNameDoesNotMatch, err) } @@ -516,21 +512,19 @@ func TestTxProcessor_CheckTxValuesMismatchedReceiverUsernamesShouldErr(t *testin t.Parallel() adr1 := []byte{65} - receiverAcc, err := state.NewUserAccount(adr1) - - assert.Nil(t, err) + receiverAcc := createUserAcc(adr1) execTx := *createTxProcessor() - receiverAcc.Balance = big.NewInt(67) - receiverAcc.UserName = []byte("RECV") + _ = receiverAcc.AddToBalance(big.NewInt(67)) + receiverAcc.SetUserName([]byte("RECV")) tx := &transaction.Transaction{ Value: big.NewInt(10), RcvUserName: []byte("notCorrect"), } - err = execTx.CheckTxValues(tx, nil, receiverAcc, false) + err := execTx.CheckTxValues(tx, nil, receiverAcc, false) assert.Equal(t, process.ErrUserNameDoesNotMatchInCrossShardTx, err) } @@ -538,26 +532,24 @@ func TestTxProcessor_CheckTxValuesCorrectUserNamesShouldWork(t *testing.T) { t.Parallel() adr1 := []byte{65} - senderAcc, err := state.NewUserAccount(adr1) - assert.Nil(t, err) + senderAcc := createUserAcc(adr1) adr2 := []byte{66} - recvAcc, err := state.NewUserAccount(adr2) - assert.Nil(t, err) + recvAcc := createUserAcc(adr2) execTx := *createTxProcessor() - senderAcc.Balance = big.NewInt(67) - senderAcc.UserName = []byte("SRC") - recvAcc.UserName = []byte("RECV") + _ = senderAcc.AddToBalance(big.NewInt(67)) + senderAcc.SetUserName([]byte("SRC")) + recvAcc.SetUserName([]byte("RECV")) tx := &transaction.Transaction{ Value: big.NewInt(10), - SndUserName: senderAcc.UserName, - RcvUserName: recvAcc.UserName, + SndUserName: senderAcc.GetUserName(), + RcvUserName: recvAcc.GetUserName(), } - err = execTx.CheckTxValues(tx, senderAcc, recvAcc, false) + err := execTx.CheckTxValues(tx, senderAcc, recvAcc, false) assert.Nil(t, err) } @@ -565,14 +557,13 @@ func TestTxProcessor_CheckTxValuesOkValsShouldErr(t *testing.T) { t.Parallel() adr1 := []byte{65} - acnt1, err := state.NewUserAccount(adr1) - assert.Nil(t, err) + acnt1 := createUserAcc(adr1) execTx := *createTxProcessor() - acnt1.Balance = big.NewInt(67) + _ = acnt1.AddToBalance(big.NewInt(67)) - err = execTx.CheckTxValues(&transaction.Transaction{Value: big.NewInt(67)}, acnt1, nil, false) + err := execTx.CheckTxValues(&transaction.Transaction{Value: big.NewInt(67)}, acnt1, nil, false) assert.Nil(t, err) } @@ -582,15 +573,14 @@ func TestTxProcessor_IncreaseNonceOkValsShouldWork(t *testing.T) { t.Parallel() adrSrc := []byte{65} - acntSrc, err := state.NewUserAccount(adrSrc) - assert.Nil(t, err) + acntSrc := createUserAcc(adrSrc) execTx := *createTxProcessor() - acntSrc.Nonce = 45 + acntSrc.IncreaseNonce(45) execTx.IncreaseNonce(acntSrc) - assert.Equal(t, uint64(46), acntSrc.Nonce) + assert.Equal(t, uint64(46), acntSrc.GetNonce()) } //------- ProcessTransaction @@ -633,10 +623,8 @@ func TestTxProcessor_ProcessCheckNotPassShouldErr(t *testing.T) { tx.RcvAddr = []byte("DST") tx.Value = big.NewInt(45) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -644,7 +632,7 @@ func TestTxProcessor_ProcessCheckNotPassShouldErr(t *testing.T) { args.Accounts = adb execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, process.ErrHigherNonceInTransaction, err) } @@ -669,10 +657,8 @@ func TestTxProcessor_ProcessWithTxFeeHandlerCheckErrorShouldErr(t *testing.T) { tx.RcvAddr = make([]byte, 32) tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -686,7 +672,7 @@ func TestTxProcessor_ProcessWithTxFeeHandlerCheckErrorShouldErr(t *testing.T) { }} execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, expectedError, err) } @@ -721,12 +707,10 @@ func TestTxProcessor_ProcessWithTxFeeHandlerInsufficientFeeShouldErr(t *testing. tx.RcvAddr = make([]byte, 32) tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(9) + _ = acntSrc.AddToBalance(big.NewInt(9)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -735,12 +719,12 @@ func TestTxProcessor_ProcessWithTxFeeHandlerInsufficientFeeShouldErr(t *testing. args.EconomicsFee = &economicsmocks.EconomicsHandlerStub{ ComputeTxFeeCalled: func(tx data.TransactionWithFeeHandler) *big.Int { - return big.NewInt(0).Add(acntSrc.Balance, big.NewInt(1)) + return big.NewInt(0).Add(acntSrc.GetBalance(), big.NewInt(1)) }} execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.True(t, errors.Is(err, process.ErrInsufficientFee)) } @@ -753,12 +737,10 @@ func TestTxProcessor_ProcessWithInsufficientFundsShouldCreateReceiptErr(t *testi tx.RcvAddr = make([]byte, 32) tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(9) + _ = acntSrc.AddToBalance(big.NewInt(9)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -772,9 +754,9 @@ func TestTxProcessor_ProcessWithInsufficientFundsShouldCreateReceiptErr(t *testi execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, process.ErrFailedTransaction, err) - assert.Equal(t, uint64(1), acntSrc.Nonce) + assert.Equal(t, uint64(1), acntSrc.GetNonce()) } func TestTxProcessor_ProcessWithUsernameMismatchCreateReceiptErr(t *testing.T) { @@ -786,12 +768,10 @@ func TestTxProcessor_ProcessWithUsernameMismatchCreateReceiptErr(t *testing.T) { tx.RcvAddr = make([]byte, 32) tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(9) + _ = acntSrc.AddToBalance(big.NewInt(9)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -805,7 +785,7 @@ func TestTxProcessor_ProcessWithUsernameMismatchCreateReceiptErr(t *testing.T) { execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) } @@ -818,12 +798,10 @@ func TestTxProcessor_ProcessWithUsernameMismatchAndSCProcessErrorShouldError(t * tx.RcvAddr = make([]byte, 32) tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(9) + _ = acntSrc.AddToBalance(big.NewInt(9)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -845,7 +823,7 @@ func TestTxProcessor_ProcessWithUsernameMismatchAndSCProcessErrorShouldError(t * execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, expectedError, err) } @@ -865,12 +843,10 @@ func TestTxProcessor_ProcessMoveBalanceToSmartPayableContract(t *testing.T) { return 0 } - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntDst.CodeMetadata = []byte{0, vmcommon.MetadataPayable} + acntDst.SetCodeMetadata([]byte{0, vmcommon.MetadataPayable}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -883,7 +859,7 @@ func TestTxProcessor_ProcessMoveBalanceToSmartPayableContract(t *testing.T) { args.ShardCoordinator = shardCoordinator execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.Equal(t, 2, saveAccountCalled) } @@ -906,10 +882,8 @@ func testProcessCheck(t *testing.T, nonce uint64, value *big.Int) { return 0 } - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -922,7 +896,7 @@ func testProcessCheck(t *testing.T, nonce uint64, value *big.Int) { args.ShardCoordinator = shardCoordinator execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.Equal(t, 1, saveAccountCalled) } @@ -938,10 +912,8 @@ func TestTxProcessor_ProcessMoveBalancesShouldWork(t *testing.T) { tx.RcvAddr = []byte("DST") tx.Value = big.NewInt(0) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -953,7 +925,7 @@ func TestTxProcessor_ProcessMoveBalancesShouldWork(t *testing.T) { args.Accounts = adb execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.Equal(t, 2, saveAccountCalled) } @@ -969,14 +941,12 @@ func TestTxProcessor_ProcessOkValsShouldWork(t *testing.T) { tx.RcvAddr = []byte("DST") tx.Value = big.NewInt(61) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Nonce = 4 - acntSrc.Balance = big.NewInt(90) - acntDst.Balance = big.NewInt(10) + acntSrc.IncreaseNonce(4) + _ = acntSrc.AddToBalance(big.NewInt(90)) + _ = acntDst.AddToBalance(big.NewInt(10)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -988,11 +958,11 @@ func TestTxProcessor_ProcessOkValsShouldWork(t *testing.T) { args.Accounts = adb execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) - assert.Equal(t, uint64(5), acntSrc.Nonce) - assert.Equal(t, big.NewInt(29), acntSrc.Balance) - assert.Equal(t, big.NewInt(71), acntDst.Balance) + assert.Equal(t, uint64(5), acntSrc.GetNonce()) + assert.Equal(t, big.NewInt(29), acntSrc.GetBalance()) + assert.Equal(t, big.NewInt(71), acntDst.GetBalance()) assert.Equal(t, 2, saveAccountCalled) } @@ -1007,14 +977,12 @@ func TestTxProcessor_MoveBalanceWithFeesShouldWork(t *testing.T) { tx.GasPrice = 2 tx.GasLimit = 2 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Nonce = 4 - acntSrc.Balance = big.NewInt(90) - acntDst.Balance = big.NewInt(10) + acntSrc.IncreaseNonce(4) + _ = acntSrc.AddToBalance(big.NewInt(90)) + _ = acntDst.AddToBalance(big.NewInt(10)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) adb.SaveAccountCalled = func(account vmcommon.AccountHandler) error { @@ -1037,11 +1005,11 @@ func TestTxProcessor_MoveBalanceWithFeesShouldWork(t *testing.T) { args.EconomicsFee = feeHandler execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) - assert.Equal(t, uint64(5), acntSrc.Nonce) - assert.Equal(t, big.NewInt(13), acntSrc.Balance) - assert.Equal(t, big.NewInt(71), acntDst.Balance) + assert.Equal(t, uint64(5), acntSrc.GetNonce()) + assert.Equal(t, big.NewInt(13), acntSrc.GetBalance()) + assert.Equal(t, big.NewInt(71), acntDst.GetBalance()) assert.Equal(t, 2, saveAccountCalled) } @@ -1057,13 +1025,10 @@ func TestTxProcessor_ProcessTransactionScDeployTxShouldWork(t *testing.T) { tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(46) + _ = acntSrc.AddToBalance(big.NewInt(46)) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -1090,7 +1055,7 @@ func TestTxProcessor_ProcessTransactionScDeployTxShouldWork(t *testing.T) { } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -1108,13 +1073,10 @@ func TestTxProcessor_ProcessTransactionBuiltInFunctionCallShouldWork(t *testing. tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(46) + _ = acntSrc.AddToBalance(big.NewInt(46)) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -1141,7 +1103,7 @@ func TestTxProcessor_ProcessTransactionBuiltInFunctionCallShouldWork(t *testing. } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -1159,13 +1121,10 @@ func TestTxProcessor_ProcessTransactionScTxShouldWork(t *testing.T) { tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + acntDst := createUserAcc(tx.RcvAddr) - acntSrc.Balance = big.NewInt(46) + _ = acntSrc.AddToBalance(big.NewInt(46)) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -1192,7 +1151,7 @@ func TestTxProcessor_ProcessTransactionScTxShouldWork(t *testing.T) { } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -1209,11 +1168,9 @@ func TestTxProcessor_ProcessTransactionScTxShouldReturnErrWhenExecutionFails(t * tx.RcvAddr = generateRandomByteSlice(createMockPubKeyConverter().Len()) tx.Value = big.NewInt(45) - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(45) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(45)) + acntDst := createUserAcc(tx.RcvAddr) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -1240,7 +1197,7 @@ func TestTxProcessor_ProcessTransactionScTxShouldReturnErrWhenExecutionFails(t * } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, process.ErrNoVM, err) assert.True(t, wasCalled) assert.Equal(t, 0, saveAccountCalled) @@ -1266,11 +1223,9 @@ func TestTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotInNod return 0 } - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(45) - acntDst, err := state.NewUserAccount(tx.RcvAddr) - assert.Nil(t, err) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(45)) + acntDst := createUserAcc(tx.RcvAddr) acntDst.SetCode([]byte{65}) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, acntDst) @@ -1293,7 +1248,7 @@ func TestTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotInNod BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -1306,7 +1261,7 @@ func TestTxProcessor_ProcessTransactionScTxShouldNotBeCalledWhenAdrDstIsNotInNod args.TxTypeHandler = computeType execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Nil(t, err) assert.False(t, wasCalled) assert.Equal(t, 1, saveAccountCalled) @@ -1537,9 +1492,8 @@ func TestTxProcessor_ProcessTransactionShouldReturnErrForInvalidMetaTx(t *testin tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(100000000) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100000000)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, nil) scProcessorMock := &testscommon.SCProcessorMock{ @@ -1562,15 +1516,15 @@ func TestTxProcessor_ProcessTransactionShouldReturnErrForInvalidMetaTx(t *testin return process.MoveBalance, process.MoveBalance }, } - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMetaProtectionFlagEnabledField: true, } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, err, process.ErrFailedTransaction) - assert.Equal(t, uint64(1), acntSrc.Nonce) - assert.Equal(t, uint64(99999999), acntSrc.Balance.Uint64()) + assert.Equal(t, uint64(1), acntSrc.GetNonce()) + assert.Equal(t, uint64(99999999), acntSrc.GetBalance().Uint64()) tx.Data = []byte("something") tx.Nonce = tx.Nonce + 1 @@ -1594,9 +1548,8 @@ func TestTxProcessor_ProcessTransactionShouldTreatAsInvalidTxIfTxTypeIsWrong(t * tx.GasPrice = 1 tx.GasLimit = 1 - acntSrc, err := state.NewUserAccount(tx.SndAddr) - assert.Nil(t, err) - acntSrc.Balance = big.NewInt(46) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(46)) adb := createAccountStub(tx.SndAddr, tx.RcvAddr, acntSrc, nil) shardC, _ := sharding.NewMultiShardCoordinator(5, 3) @@ -1615,10 +1568,10 @@ func TestTxProcessor_ProcessTransactionShouldTreatAsInvalidTxIfTxTypeIsWrong(t * } execTx, _ := txproc.NewTxProcessor(args) - _, err = execTx.ProcessTransaction(&tx) + _, err := execTx.ProcessTransaction(&tx) assert.Equal(t, err, process.ErrFailedTransaction) - assert.Equal(t, uint64(1), acntSrc.Nonce) - assert.Equal(t, uint64(45), acntSrc.Balance.Uint64()) + assert.Equal(t, uint64(1), acntSrc.GetNonce()) + assert.Equal(t, uint64(45), acntSrc.GetBalance().Uint64()) } func TestTxProcessor_ProcessRelayedTransactionV2NotActiveShouldErr(t *testing.T) { @@ -1651,13 +1604,13 @@ func TestTxProcessor_ProcessRelayedTransactionV2NotActiveShouldErr(t *testing.T) "@" + "01a2") - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) - acntFinal, _ := state.NewUserAccount(userTxDest) - acntFinal.Balance = big.NewInt(10) + acntFinal := createUserAcc(userTxDest) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -1683,7 +1636,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2NotActiveShouldErr(t *testing.T) BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -1733,13 +1686,13 @@ func TestTxProcessor_ProcessRelayedTransactionV2WithValueShouldErr(t *testing.T) "@" + "01a2") - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) - acntFinal, _ := state.NewUserAccount(userTxDest) - acntFinal.Balance = big.NewInt(10) + acntFinal := createUserAcc(userTxDest) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -1765,7 +1718,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2WithValueShouldErr(t *testing.T) BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -1815,13 +1768,13 @@ func TestTxProcessor_ProcessRelayedTransactionV2ArgsParserShouldErr(t *testing.T "@" + "01a2") - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) - acntFinal, _ := state.NewUserAccount(userTxDest) - acntFinal.Balance = big.NewInt(10) + acntFinal := createUserAcc(userTxDest) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -1847,7 +1800,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2ArgsParserShouldErr(t *testing.T BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -1904,13 +1857,13 @@ func TestTxProcessor_ProcessRelayedTransactionV2InvalidParamCountShouldErr(t *te "@" + "1010") - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) - acntFinal, _ := state.NewUserAccount(userTxDest) - acntFinal.Balance = big.NewInt(10) + acntFinal := createUserAcc(userTxDest) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -1936,7 +1889,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2InvalidParamCountShouldErr(t *te BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -1986,13 +1939,13 @@ func TestTxProcessor_ProcessRelayedTransactionV2(t *testing.T) { "@" + "01a2") - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) - acntFinal, _ := state.NewUserAccount(userTxDest) - acntFinal.Balance = big.NewInt(10) + acntFinal := createUserAcc(userTxDest) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2018,7 +1971,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2(t *testing.T) { BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -2031,7 +1984,7 @@ func TestTxProcessor_ProcessRelayedTransactionV2(t *testing.T) { args.TxTypeHandler = txTypeHandler args.PubkeyConv = pubKeyConverter args.ArgsParser = smartContract.NewArgumentParser() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRelayedTransactionsV2FlagEnabledField: true, } execTx, _ := txproc.NewTxProcessor(args) @@ -2067,12 +2020,12 @@ func TestTxProcessor_ProcessRelayedTransaction(t *testing.T) { userTxMarshalled, _ := marshalizer.Marshal(userTx) tx.Data = []byte(core.RelayedTransaction + "@" + hex.EncodeToString(userTxMarshalled)) - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2098,7 +2051,7 @@ func TestTxProcessor_ProcessRelayedTransaction(t *testing.T) { BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -2111,7 +2064,7 @@ func TestTxProcessor_ProcessRelayedTransaction(t *testing.T) { args.TxTypeHandler = txTypeHandler args.PubkeyConv = pubKeyConverter args.ArgsParser = smartContract.NewArgumentParser() - args.EnableEpochsHandler = &testscommon.EnableEpochsHandlerStub{ + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsRelayedTransactionsFlagEnabledField: true, } execTx, _ := txproc.NewTxProcessor(args) @@ -2163,12 +2116,12 @@ func TestTxProcessor_ProcessRelayedTransactionArgsParserErrorShouldError(t *test return "", nil, parseError }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2226,12 +2179,12 @@ func TestTxProcessor_ProcessRelayedTransactionMultipleArgumentsShouldError(t *te return core.RelayedTransaction, [][]byte{[]byte("0"), []byte("1")}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2289,12 +2242,12 @@ func TestTxProcessor_ProcessRelayedTransactionFailUnMarshalInnerShouldError(t *t return core.RelayedTransaction, [][]byte{[]byte("0")}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2352,12 +2305,12 @@ func TestTxProcessor_ProcessRelayedTransactionDifferentSenderInInnerTxThanReceiv return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2415,12 +2368,12 @@ func TestTxProcessor_ProcessRelayedTransactionSmallerValueInnerTxShouldError(t * return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2478,12 +2431,12 @@ func TestTxProcessor_ProcessRelayedTransactionGasPriceMismatchShouldError(t *tes return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2541,12 +2494,12 @@ func TestTxProcessor_ProcessRelayedTransactionGasLimitMismatchShouldError(t *tes return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2600,12 +2553,12 @@ func TestTxProcessor_ProcessRelayedTransactionDisabled(t *testing.T) { userTxMarshalled, _ := marshalizer.Marshal(userTx) tx.Data = []byte(core.RelayedTransaction + "@" + hex.EncodeToString(userTxMarshalled)) - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(10) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(10) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(10)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(10)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2631,7 +2584,7 @@ func TestTxProcessor_ProcessRelayedTransactionDisabled(t *testing.T) { BuiltInFunctions: builtInFunctions.NewBuiltInFunctionContainer(), ArgumentParser: parsers.NewCallArgsParser(), ESDTTransferParser: esdtTransferParser, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTMetadataContinuousCleanupFlagEnabledField: true, }, } @@ -2678,8 +2631,8 @@ func TestTxProcessor_ConsumeMoveBalanceWithUserTx(t *testing.T) { } execTx, _ := txproc.NewTxProcessor(args) - acntSrc, _ := state.NewUserAccount([]byte("address")) - acntSrc.Balance = big.NewInt(100) + acntSrc := createUserAcc([]byte("address")) + _ = acntSrc.AddToBalance(big.NewInt(100)) originalTxHash := []byte("originalTxHash") userTx := &transaction.Transaction{ @@ -2691,7 +2644,7 @@ func TestTxProcessor_ConsumeMoveBalanceWithUserTx(t *testing.T) { err := execTx.ProcessMoveBalanceCostRelayedUserTx(userTx, &smartContractResult.SmartContractResult{}, acntSrc, originalTxHash) assert.Nil(t, err) - assert.Equal(t, acntSrc.Balance, big.NewInt(99)) + assert.Equal(t, acntSrc.GetBalance(), big.NewInt(99)) } func TestTxProcessor_IsCrossTxFromMeShouldWork(t *testing.T) { @@ -2738,12 +2691,12 @@ func TestTxProcessor_ProcessUserTxOfTypeRelayedShouldError(t *testing.T) { return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2802,12 +2755,12 @@ func TestTxProcessor_ProcessUserTxOfTypeMoveBalanceShouldWork(t *testing.T) { return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2866,12 +2819,12 @@ func TestTxProcessor_ProcessUserTxOfTypeSCDeploymentShouldWork(t *testing.T) { return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2930,12 +2883,12 @@ func TestTxProcessor_ProcessUserTxOfTypeSCInvokingShouldWork(t *testing.T) { return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -2994,12 +2947,12 @@ func TestTxProcessor_ProcessUserTxOfTypeBuiltInFunctionCallShouldWork(t *testing return core.RelayedTransaction, [][]byte{userTxMarshalled}, nil }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -3062,12 +3015,12 @@ func TestTxProcessor_ProcessUserTxErrNotPayableShouldFailRelayTx(t *testing.T) { return false, process.ErrAccountNotPayable }} - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -3132,12 +3085,12 @@ func TestTxProcessor_ProcessUserTxFailedBuiltInFunctionCall(t *testing.T) { }, } - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { @@ -3192,12 +3145,12 @@ func TestTxProcessor_ExecuteFailingRelayedTxShouldNotHaveNegativeFee(t *testing. args := createArgsForTxProcessor() - acntSrc, _ := state.NewUserAccount(tx.SndAddr) - acntSrc.Balance = big.NewInt(100) - acntDst, _ := state.NewUserAccount(tx.RcvAddr) - acntDst.Balance = big.NewInt(100) - acntFinal, _ := state.NewUserAccount(userTx.RcvAddr) - acntFinal.Balance = big.NewInt(100) + acntSrc := createUserAcc(tx.SndAddr) + _ = acntSrc.AddToBalance(big.NewInt(100)) + acntDst := createUserAcc(tx.RcvAddr) + _ = acntDst.AddToBalance(big.NewInt(100)) + acntFinal := createUserAcc(userTx.RcvAddr) + _ = acntFinal.AddToBalance(big.NewInt(100)) adb := &stateMock.AccountsStub{} adb.LoadAccountCalled = func(address []byte) (vmcommon.AccountHandler, error) { diff --git a/process/transaction/transactionCostEstimator_test.go b/process/transaction/transactionCostEstimator_test.go index 66338b83d9f..5325e123617 100644 --- a/process/transaction/transactionCostEstimator_test.go +++ b/process/transaction/transactionCostEstimator_test.go @@ -16,6 +16,7 @@ import ( txSimData "github.com/multiversx/mx-chain-go/process/txsimulator/data" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/require" @@ -30,7 +31,7 @@ func TestTransactionCostEstimator_NilTxTypeHandler(t *testing.T) { &mock.TransactionSimulatorStub{}, &stateMock.AccountsStub{}, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) require.Nil(t, tce) require.Equal(t, process.ErrNilTxTypeHandler, err) @@ -45,7 +46,7 @@ func TestTransactionCostEstimator_NilFeeHandlerShouldErr(t *testing.T) { &mock.TransactionSimulatorStub{}, &stateMock.AccountsStub{}, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) require.Nil(t, tce) require.Equal(t, process.ErrNilEconomicsFeeHandler, err) @@ -60,7 +61,7 @@ func TestTransactionCostEstimator_NilTransactionSimulatorShouldErr(t *testing.T) nil, &stateMock.AccountsStub{}, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) require.Nil(t, tce) require.Equal(t, txsimulator.ErrNilTxSimulatorProcessor, err) @@ -90,7 +91,7 @@ func TestTransactionCostEstimator_Ok(t *testing.T) { &mock.TransactionSimulatorStub{}, &stateMock.AccountsStub{}, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) require.Nil(t, err) require.False(t, check.IfNil(tce)) @@ -120,7 +121,7 @@ func TestComputeTransactionGasLimit_MoveBalance(t *testing.T) { return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -153,7 +154,7 @@ func TestComputeTransactionGasLimit_MoveBalanceInvalidNonceShouldStillComputeCos return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -186,7 +187,7 @@ func TestComputeTransactionGasLimit_BuiltInFunction(t *testing.T) { return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -214,7 +215,7 @@ func TestComputeTransactionGasLimit_BuiltInFunctionShouldErr(t *testing.T) { return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -241,7 +242,7 @@ func TestComputeTransactionGasLimit_NilVMOutput(t *testing.T) { return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -272,7 +273,7 @@ func TestComputeTransactionGasLimit_RetCodeNotOk(t *testing.T) { return &stateMock.UserAccountStub{Balance: big.NewInt(100000)}, nil }, }, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) @@ -293,7 +294,7 @@ func TestTransactionCostEstimator_RelayedTxShouldErr(t *testing.T) { &mock.TransactionSimulatorStub{}, &stateMock.AccountsStub{}, &mock.ShardCoordinatorStub{}, - &testscommon.EnableEpochsHandlerStub{}) + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) tx := &transaction.Transaction{} cost, err := tce.ComputeTransactionGasLimit(tx) diff --git a/process/txsSender/txsSender_test.go b/process/txsSender/txsSender_test.go index eb5b06105ce..50ee7af2876 100644 --- a/process/txsSender/txsSender_test.go +++ b/process/txsSender/txsSender_test.go @@ -23,6 +23,7 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -103,7 +104,7 @@ func TestNewTxsSenderWithAccumulator(t *testing.T) { func TestTxsSender_SendBulkTransactions(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} mutRecoveredTransactions := &sync.RWMutex{} recoveredTransactions := make(map[uint32][]*transaction.Transaction) shardCoordinator := testscommon.NewMultiShardsCoordinatorMock(2) @@ -178,9 +179,9 @@ func TestTxsSender_SendBulkTransactions(t *testing.T) { } }, } - dataPacker, _ := partitioning.NewSimpleDataPacker(&testscommon.MarshalizerMock{}) + dataPacker, _ := partitioning.NewSimpleDataPacker(&marshallerMock.MarshalizerMock{}) args := ArgsTxsSenderWithAccumulator{ - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, ShardCoordinator: shardCoordinator, NetworkMessenger: mes, DataPacker: dataPacker, @@ -251,7 +252,7 @@ func TestTxsSender_sendFromTxAccumulatorSendOneTxOneSCRExpectOnlyTxToBeSent(t *t } txMarshalled := []byte("txMarshalled") - marshaller := &testscommon.MarshalizerStub{ + marshaller := &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { ctMarshallCalled.Increment() require.Equal(t, tx, obj) @@ -330,7 +331,7 @@ func TestTxsSender_sendBulkTransactionsSendTwoTxsFailToMarshallOneExpectOnlyOneT } tx1Marshalled := []byte("tx1Marshalled") - marshaller := &testscommon.MarshalizerStub{ + marshaller := &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { switch ctMarshallCalled.Get() { case 0: @@ -423,7 +424,7 @@ func TestTxsSender_SendBulkTransactionsNoTxToProcessExpectError(t *testing.T) { } func generateMockArgsTxsSender() ArgsTxsSenderWithAccumulator { - marshaller := testscommon.MarshalizerMock{} + marshaller := marshallerMock.MarshalizerMock{} dataPacker, _ := partitioning.NewSimpleDataPacker(marshaller) accumulatorConfig := config.TxAccumulatorConfig{ MaxAllowedTimeInMilliseconds: 10, diff --git a/process/txsimulator/wrappedAccountsDB.go b/process/txsimulator/wrappedAccountsDB.go index 45b8b23fef3..4744e6bf02b 100644 --- a/process/txsimulator/wrappedAccountsDB.go +++ b/process/txsimulator/wrappedAccountsDB.go @@ -115,8 +115,8 @@ func (r *readOnlyAccountsDB) IsPruningEnabled() bool { } // GetAllLeaves will call the original accounts' function with the same name -func (r *readOnlyAccountsDB) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { - return r.originalAccounts.GetAllLeaves(leavesChannels, ctx, rootHash) +func (r *readOnlyAccountsDB) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error { + return r.originalAccounts.GetAllLeaves(leavesChannels, ctx, rootHash, trieLeafParser) } // RecreateAllTries will return an error which indicates that this operation is not supported diff --git a/process/txsimulator/wrappedAccountsDB_test.go b/process/txsimulator/wrappedAccountsDB_test.go index e83fe6a0d58..8666106d2d0 100644 --- a/process/txsimulator/wrappedAccountsDB_test.go +++ b/process/txsimulator/wrappedAccountsDB_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/require" @@ -153,7 +154,7 @@ func TestReadOnlyAccountsDB_ReadOperationsShouldWork(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder), ErrChan: errChan.NewErrChanWrapper(), } - err = roAccDb.GetAllLeaves(allLeaves, context.Background(), nil) + err = roAccDb.GetAllLeaves(allLeaves, context.Background(), nil, parsers.NewMainTrieLeafParser()) require.NoError(t, err) err = allLeaves.ErrChan.ReadFromChanNonBlocking() diff --git a/sharding/mock/enableEpochsHandlerMock.go b/sharding/mock/enableEpochsHandlerMock.go index dca1d41a6c7..4213f7733e3 100644 --- a/sharding/mock/enableEpochsHandlerMock.go +++ b/sharding/mock/enableEpochsHandlerMock.go @@ -591,6 +591,11 @@ func (mock *EnableEpochsHandlerMock) IsConsistentTokensValuesLengthCheckEnabled( return false } +// IsAutoBalanceDataTriesEnabled - +func (mock *EnableEpochsHandlerMock) IsAutoBalanceDataTriesEnabled() bool { + return false +} + // IsInterfaceNil returns true if there is no value under the interface func (mock *EnableEpochsHandlerMock) IsInterfaceNil() bool { return mock == nil diff --git a/state/accountsDB.go b/state/accountsDB.go index 5daea6408ab..acd233ae5a4 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/trie/storageMarker" @@ -435,7 +436,7 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error } if oldCodeEntry.NumReferences <= 1 { - err = adb.mainTrie.Update(oldCodeHash, nil) + err = adb.mainTrie.Delete(oldCodeHash) if err != nil { return nil, err } @@ -553,6 +554,9 @@ func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) error { } adb.journalize(entry) + //TODO in order to avoid recomputing the root hash after every transaction for the same data trie, + // benchmark if it is better to cache the account and compute the rootHash only when the state is committed. + // For this to work, LoadAccount should check that cache first, and only after load from the trie. rootHash, err := accountHandler.DataTrie().RootHash() if err != nil { return err @@ -620,7 +624,7 @@ func (adb *AccountsDB) RemoveAccount(address []byte) error { "address", hex.EncodeToString(address), ) - return adb.mainTrie.Update(address, make([]byte, 0)) + return adb.mainTrie.Delete(address) } func (adb *AccountsDB) removeCodeAndDataTrie(acnt vmcommon.AccountHandler) error { @@ -1043,7 +1047,13 @@ func (adb *AccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie ErrChan: errChan.NewErrChanWrapper(), } mainTrie := adb.getMainTrie() - err := mainTrie.GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + err := mainTrie.GetAllLeavesOnChannel( + leavesChannels, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) if err != nil { return nil, err } @@ -1416,8 +1426,8 @@ func (adb *AccountsDB) IsPruningEnabled() bool { } // GetAllLeaves returns all the leaves from a given rootHash -func (adb *AccountsDB) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { - return adb.getMainTrie().GetAllLeavesOnChannel(leavesChannels, ctx, rootHash, keyBuilder.NewKeyBuilder()) +func (adb *AccountsDB) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error { + return adb.getMainTrie().GetAllLeavesOnChannel(leavesChannels, ctx, rootHash, keyBuilder.NewKeyBuilder(), trieLeafParser) } // Close will handle the closing of the underlying components @@ -1439,13 +1449,19 @@ func (adb *AccountsDB) GetStatsForRootHash(rootHash []byte) (common.TriesStatist return nil, fmt.Errorf("invalid trie, type is %T", mainTrie) } - collectStats(tr, stats, rootHash, "") + collectStats(tr, stats, rootHash, "", common.MainTrie) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, leavesChannelSize), ErrChan: errChan.NewErrChanWrapper(), } - err := mainTrie.GetAllLeavesOnChannel(iteratorChannels, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + err := mainTrie.GetAllLeavesOnChannel( + iteratorChannels, + context.Background(), + rootHash, + keyBuilder.NewDisabledKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) if err != nil { return nil, err } @@ -1467,7 +1483,7 @@ func (adb *AccountsDB) GetStatsForRootHash(rootHash []byte) (common.TriesStatist return nil, err } - collectStats(tr, stats, account.RootHash, accountAddress) + collectStats(tr, stats, account.RootHash, accountAddress, common.DataTrie) } err = iteratorChannels.ErrChan.ReadFromChanNonBlocking() @@ -1478,13 +1494,19 @@ func (adb *AccountsDB) GetStatsForRootHash(rootHash []byte) (common.TriesStatist return stats, nil } -func collectStats(tr common.TrieStats, stats common.TriesStatisticsCollector, rootHash []byte, address string) { +func collectStats( + tr common.TrieStats, + stats common.TriesStatisticsCollector, + rootHash []byte, + address string, + trieType common.TrieType, +) { trieStats, err := tr.GetTrieStats(address, rootHash) if err != nil { log.Error(err.Error()) return } - stats.Add(trieStats) + stats.Add(trieStats, trieType) log.Debug(strings.Join(trieStats.ToString(), " ")) } diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index 6b9f733e459..c91fecdda64 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -196,13 +196,13 @@ func (accountsDB *accountsDBApi) IsPruningEnabled() bool { } // GetAllLeaves will call the inner accountsAdapter method after trying to recreate the trie -func (accountsDB *accountsDBApi) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { +func (accountsDB *accountsDBApi) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error { _, err := accountsDB.recreateTrieIfNecessary() if err != nil { return err } - return accountsDB.innerAccountsAdapter.GetAllLeaves(leavesChannels, ctx, rootHash) + return accountsDB.innerAccountsAdapter.GetAllLeaves(leavesChannels, ctx, rootHash, trieLeafParser) } // RecreateAllTries is a not permitted operation in this implementation and thus, will return an error diff --git a/state/accountsDBApiWithHistory.go b/state/accountsDBApiWithHistory.go index 08eb14b8378..0f95117eba9 100644 --- a/state/accountsDBApiWithHistory.go +++ b/state/accountsDBApiWithHistory.go @@ -125,7 +125,7 @@ func (accountsDB *accountsDBApiWithHistory) IsPruningEnabled() bool { } // GetAllLeaves will return an error -func (accountsDB *accountsDBApiWithHistory) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context, _ []byte) error { +func (accountsDB *accountsDBApiWithHistory) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { return ErrOperationNotPermitted } diff --git a/state/accountsDBApiWithHistory_test.go b/state/accountsDBApiWithHistory_test.go index 8e1b9a5d4e6..5bcd1b20c6a 100644 --- a/state/accountsDBApiWithHistory_test.go +++ b/state/accountsDBApiWithHistory_test.go @@ -84,7 +84,7 @@ func TestAccountsDBApiWithHistory_NotPermittedOrNotImplementedOperationsDoNotPan accountsApi.SetStateCheckpoint(nil) assert.Equal(t, false, accountsApi.IsPruningEnabled()) - assert.Equal(t, state.ErrOperationNotPermitted, accountsApi.GetAllLeaves(&common.TrieIteratorChannels{}, nil, nil)) + assert.Equal(t, state.ErrOperationNotPermitted, accountsApi.GetAllLeaves(&common.TrieIteratorChannels{}, nil, nil, nil)) resultedMap, err := accountsApi.RecreateAllTries(nil) assert.Nil(t, resultedMap) @@ -129,7 +129,7 @@ func TestAccountsDBApiWithHistory_GetAccountWithBlockInfo(t *testing.T) { }, GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { if bytes.Equal(address, testscommon.TestPubKeyAlice) { - return state.NewUserAccount(address) + return createUserAcc(address), nil } return nil, errors.New("not found") @@ -307,7 +307,9 @@ func TestAccountsDBApiWithHistory_GetAccountWithBlockInfoWhenHighConcurrency(t * } func createDummyAccountWithBalanceString(balanceString string) state.UserAccountHandler { - dummyAccount := state.NewEmptyUserAccount() + dummyAccount := &mockState.AccountWrapMock{ + Balance: big.NewInt(0), + } dummyBalance, _ := big.NewInt(0).SetString(balanceString, 10) _ = dummyAccount.AddToBalance(dummyBalance) diff --git a/state/accountsDBApi_test.go b/state/accountsDBApi_test.go index c0bba62df79..f2d29cf3ce2 100644 --- a/state/accountsDBApi_test.go +++ b/state/accountsDBApi_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon" mockState "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/trie" @@ -308,7 +309,7 @@ func TestAccountsDBApi_GetExistingAccount(t *testing.T) { return nil }, GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(addressContainer) + return createUserAcc(addressContainer), nil }, } @@ -352,7 +353,7 @@ func TestAccountsDBApi_GetAccountFromBytes(t *testing.T) { return nil }, GetAccountFromBytesCalled: func(address []byte, accountBytes []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createUserAcc(address), nil }, } @@ -396,7 +397,7 @@ func TestAccountsDBApi_LoadAccount(t *testing.T) { return nil }, LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return createUserAcc(address), nil }, } @@ -462,14 +463,14 @@ func TestAccountsDBApi_GetAllLeaves(t *testing.T) { RecreateTrieCalled: func(rootHash []byte) error { return expectedErr }, - GetAllLeavesCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte) error { + GetAllLeavesCalled: func(_ *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.TrieLeafParser) error { require.Fail(t, "should have not called inner method") return nil }, } accountsApi, _ := state.NewAccountsDBApi(accountsAdapter, createBlockInfoProviderStub(dummyRootHash)) - err := accountsApi.GetAllLeaves(&common.TrieIteratorChannels{}, nil, []byte{}) + err := accountsApi.GetAllLeaves(&common.TrieIteratorChannels{}, nil, []byte{}, parsers.NewMainTrieLeafParser()) assert.Equal(t, expectedErr, err) }) t.Run("recreate trie works, should call inner method", func(t *testing.T) { @@ -485,7 +486,7 @@ func TestAccountsDBApi_GetAllLeaves(t *testing.T) { } accountsApi, _ := state.NewAccountsDBApi(accountsAdapter, createBlockInfoProviderStub(dummyRootHash)) - err := accountsApi.GetAllLeaves(providedChan, context.Background(), []byte("address")) + err := accountsApi.GetAllLeaves(providedChan, context.Background(), []byte("address"), parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.True(t, recreateTrieCalled) }) @@ -562,7 +563,9 @@ func TestAccountsDBApi_GetAccountWithBlockInfoWhenHighConcurrency(t *testing.T) } func createDummyAccountWithBalanceBytes(balanceBytes []byte) state.UserAccountHandler { - dummyAccount := state.NewEmptyUserAccount() + dummyAccount := &mockState.AccountWrapMock{ + Balance: big.NewInt(0), + } dummyBalance := big.NewInt(0).SetBytes(balanceBytes) _ = dummyAccount.AddToBalance(dummyBalance) diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index ae56e67ccfd..daf22075d7e 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/keyValStorage" + "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" @@ -24,11 +25,14 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/factory" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -51,9 +55,9 @@ func createMockAccountsDBArgs() state.ArgsAccountsDB { }, }, Hasher: &hashingMocks.HasherMock{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, AccountFactory: &stateMock.AccountsFactoryStub{ - CreateAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { + CreateAccountCalled: func(address []byte, _ hashing.Hasher, _ marshal.Marshalizer) (vmcommon.AccountHandler, error) { return stateMock.NewAccountWrapMock(address), nil }, }, @@ -65,6 +69,19 @@ func createMockAccountsDBArgs() state.ArgsAccountsDB { } } +func getDefaultArgsAccountCreation() state.ArgsAccountCreation { + return state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } +} + +func createUserAcc(address []byte) state.UserAccountHandler { + acc, _ := state.NewUserAccount(address, getDefaultArgsAccountCreation()) + return acc +} + func generateAccountDBFromTrie(trie common.Trie) *state.AccountsDB { args := createMockAccountsDBArgs() args.Trie = trie @@ -107,26 +124,32 @@ func getDefaultStateComponents( SnapshotsBufferLen: 10, SnapshotsGoroutineNum: 1, } - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} args := storage.GetStorageManagerArgs() args.MainStorer = db args.CheckpointHashesHolder = hashesHolder trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, 5) + tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, generalCfg.PruningBufferLen) + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) argsAccountsDB := state.ArgsAccountsDB{ Trie: tr, Hasher: hasher, Marshaller: marshaller, - AccountFactory: factory.NewAccountCreator(), + AccountFactory: accCreator, StoragePruningManager: spm, ProcessingMode: common.Normal, ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, @@ -269,7 +292,7 @@ func TestAccountsDB_SaveAccountNilOldAccount(t *testing.T) { }, }) - acc, _ := state.NewUserAccount([]byte("someAddress")) + acc := createUserAcc([]byte("someAddress")) err := adb.SaveAccount(acc) assert.Nil(t, err) assert.Equal(t, 1, adb.JournalLen()) @@ -278,11 +301,10 @@ func TestAccountsDB_SaveAccountNilOldAccount(t *testing.T) { func TestAccountsDB_SaveAccountExistingOldAccount(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount([]byte("someAddress")) - + acc := createUserAcc([]byte("someAddress")) adb := generateAccountDBFromTrie(&trieMock.TrieStub{ GetCalled: func(_ []byte) ([]byte, uint32, error) { - serializedAcc, err := (&testscommon.MarshalizerMock{}).Marshal(acc) + serializedAcc, err := (&marshallerMock.MarshalizerMock{}).Marshal(acc) return serializedAcc, 0, err }, UpdateCalled: func(key, value []byte) error { @@ -306,7 +328,7 @@ func TestAccountsDB_SaveAccountSavesCodeAndDataTrieForUserAccount(t *testing.T) GetCalled: func(_ []byte) ([]byte, uint32, error) { return nil, 0, nil }, - UpdateCalled: func(key, value []byte) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, RootCalled: func() (i []byte, err error) { @@ -331,7 +353,7 @@ func TestAccountsDB_SaveAccountSavesCodeAndDataTrieForUserAccount(t *testing.T) }) accCode := []byte("code") - acc, _ := state.NewUserAccount([]byte("someAddress")) + acc := createUserAcc([]byte("someAddress")) acc.SetCode(accCode) _ = acc.SaveKeyValue([]byte("key"), []byte("value")) @@ -351,7 +373,7 @@ func TestAccountsDB_SaveAccountMalfunctionMarshallerShouldErr(t *testing.T) { return &storageManager.StorageManagerStub{} }, } - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} args := createMockAccountsDBArgs() args.Marshaller = marshaller args.Trie = mockTrie @@ -393,13 +415,13 @@ func TestAccountsDB_RemoveAccountShouldWork(t *testing.T) { t.Parallel() wasCalled := false - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} trieStub := &trieMock.TrieStub{ GetCalled: func(_ []byte) ([]byte, uint32, error) { serializedAcc, err := marshaller.Marshal(stateMock.AccountWrapMock{}) return serializedAcc, 0, err }, - UpdateCalled: func(key, value []byte) error { + DeleteCalled: func(key []byte) error { wasCalled = true return nil }, @@ -468,7 +490,7 @@ func TestAccountsDB_LoadAccountExistingShouldLoadDataTrie(t *testing.T) { acc.SetCodeHash(codeHash) code := []byte("code") dataTrie := &trieMock.TrieStub{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} trieStub := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -546,7 +568,7 @@ func TestAccountsDB_GetExistingAccountFoundShouldRetAccount(t *testing.T) { acc.SetCodeHash(codeHash) code := []byte("code") dataTrie := &trieMock.TrieStub{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} trieStub := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { @@ -593,7 +615,7 @@ func TestAccountsDB_GetAccountAccountNotFound(t *testing.T) { testAccount.MockValue = 45 // Step 2. marshalize the DbAccount - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} buff, err := marshaller.Marshal(testAccount) assert.Nil(t, err) @@ -662,7 +684,7 @@ func TestAccountsDB_LoadCodeOkValsShouldWork(t *testing.T) { }, } adr, account, _ := generateAddressAccountAccountsDB(tr) - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} trieStub := &trieMock.TrieStub{ GetCalled: func(_ []byte) ([]byte, uint32, error) { @@ -814,7 +836,7 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { t.Parallel() commitCalled := 0 - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) trieStub := trieMock.TrieStub{ CommitCalled: func() error { @@ -833,7 +855,7 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { GetCalled: func(_ []byte) ([]byte, uint32, error) { return []byte("doge"), 0, nil }, - UpdateCalled: func(key, value []byte) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, CommitCalled: func() error { @@ -1422,7 +1444,7 @@ func TestAccountsDB_GetAllLeaves(t *testing.T) { getAllLeavesCalled := false trieStub := &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, builder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, builder common.KeyBuilder, _ common.TrieLeafParser) error { getAllLeavesCalled = true close(channels.LeavesChan) channels.ErrChan.Close() @@ -1440,7 +1462,7 @@ func TestAccountsDB_GetAllLeaves(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := adb.GetAllLeaves(leavesChannel, context.Background(), []byte("root hash")) + err := adb.GetAllLeaves(leavesChannel, context.Background(), []byte("root hash"), parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.True(t, getAllLeavesCalled) @@ -1471,7 +1493,7 @@ func checkCodeEntry( func TestAccountsDB_SaveAccountSavesCodeIfCodeHashIsSet(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1506,7 +1528,7 @@ func TestAccountsDB_saveCode_OldCodeAndNewCodeAreNil(t *testing.T) { func TestAccountsDB_saveCode_OldCodeIsNilAndNewCodeIsNotNilAndRevert(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1535,7 +1557,7 @@ func TestAccountsDB_saveCode_OldCodeIsNilAndNewCodeIsNotNilAndRevert(t *testing. func TestAccountsDB_saveCode_OldCodeIsNilAndNewCodeAlreadyExistsAndRevert(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1571,7 +1593,7 @@ func TestAccountsDB_saveCode_OldCodeIsNilAndNewCodeAlreadyExistsAndRevert(t *tes func TestAccountsDB_saveCode_OldCodeExistsAndNewCodeIsNilAndRevert(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1607,7 +1629,7 @@ func TestAccountsDB_saveCode_OldCodeExistsAndNewCodeIsNilAndRevert(t *testing.T) func TestAccountsDB_saveCode_OldCodeExistsAndNewCodeExistsAndRevert(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1653,7 +1675,7 @@ func TestAccountsDB_saveCode_OldCodeExistsAndNewCodeExistsAndRevert(t *testing.T func TestAccountsDB_saveCode_OldCodeIsReferencedMultipleTimesAndNewCodeIsNilAndRevert(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} tr, adb := getDefaultTrieAndAccountsDb() @@ -1728,20 +1750,25 @@ func TestAccountsDB_RemoveAccountAlsoRemovesCodeAndRevertsCorrectly(t *testing.T func TestAccountsDB_MainTrieAutomaticallyMarksCodeUpdatesForEviction(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} ewl := stateMock.NewEvictionWaitingListMock(100) args := storage.GetStorageManagerArgs() storageManager, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(storageManager, marshaller, hasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(storageManager, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 5) argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccountsDB.AccountFactory = factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -1804,20 +1831,25 @@ func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) { t.Parallel() maxTrieLevelInMemory := uint(5) - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} ewl := stateMock.NewEvictionWaitingListMock(100) args := storage.GetStorageManagerArgs() storageManager, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(storageManager, marshaller, hasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(storageManager, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 5) argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccountsDB.AccountFactory = factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2228,18 +2260,23 @@ func TestAccountsDB_GetCode(t *testing.T) { t.Parallel() maxTrieLevelInMemory := uint(5) - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} args := storage.GetStorageManagerArgs() storageManager, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(storageManager, marshaller, hasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(storageManager, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm := disabled.NewDisabledStoragePruningManager() argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccountsDB.AccountFactory = factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2297,7 +2334,7 @@ func TestAccountsDB_RecreateAllTries(t *testing.T) { expectedErr := errors.New("expected error") args.Trie = &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, _ common.TrieLeafParser) error { go func() { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("key"), []byte("val")) leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) @@ -2326,7 +2363,7 @@ func TestAccountsDB_RecreateAllTries(t *testing.T) { args := createMockAccountsDBArgs() args.Trie = &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, _ common.TrieLeafParser) error { go func() { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("key"), []byte("val")) @@ -2382,7 +2419,7 @@ func TestAccountsDB_Close(t *testing.T) { return &storageManager.StorageManagerStub{} }, } - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, @@ -2395,7 +2432,12 @@ func TestAccountsDB_Close(t *testing.T) { argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccountsDB.AccountFactory = factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2419,9 +2461,9 @@ func TestAccountsDB_GetAccountFromBytesInvalidAddress(t *testing.T) { func TestAccountsDB_GetAccountFromBytes(t *testing.T) { t.Parallel() - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} adr := make([]byte, 32) - accountExpected, _ := state.NewUserAccount(adr) + accountExpected := createUserAcc(adr) accountBytes, _ := marshaller.Marshal(accountExpected) _, adb := getDefaultTrieAndAccountsDb() @@ -2436,7 +2478,7 @@ func TestAccountsDB_GetAccountFromBytesShouldLoadDataTrie(t *testing.T) { acc := generateAccount() acc.SetRootHash([]byte("root hash")) dataTrie := &trieMock.TrieStub{} - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} serializerAcc, _ := marshaller.Marshal(acc) trieStub := &trieMock.TrieStub{ @@ -2610,19 +2652,24 @@ func TestAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { func BenchmarkAccountsDb_GetCodeEntry(b *testing.B) { maxTrieLevelInMemory := uint(5) - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} args := storage.GetStorageManagerArgs() storageManager, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(storageManager, marshaller, hasher, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(storageManager, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) spm := disabled.NewDisabledStoragePruningManager() argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccountsDB.AccountFactory = factory.NewAccountCreator() + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2971,7 +3018,7 @@ func testAccountMethodsConcurrency( rootHash []byte, ) { numOperations := 100 - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} wg := sync.WaitGroup{} wg.Add(numOperations) @@ -3024,7 +3071,7 @@ func testAccountMethodsConcurrency( case 17: _ = adb.IsPruningEnabled() case 18: - _ = adb.GetAllLeaves(&common.TrieIteratorChannels{}, context.Background(), rootHash) + _ = adb.GetAllLeaves(&common.TrieIteratorChannels{}, context.Background(), rootHash, parsers.NewMainTrieLeafParser()) case 19: _, _ = adb.RecreateAllTries(rootHash) case 20: diff --git a/state/baseAccount.go b/state/baseAccount.go index 07c1d7c6778..e5e94367ca8 100644 --- a/state/baseAccount.go +++ b/state/baseAccount.go @@ -1,6 +1,7 @@ package state import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -53,7 +54,7 @@ func (ba *baseAccount) SaveKeyValue(key []byte, value []byte) error { } // SaveDirtyData triggers SaveDirtyData form the underlying trackableDataTrie -func (ba *baseAccount) SaveDirtyData(trie common.Trie) (map[string][]byte, error) { +func (ba *baseAccount) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { if check.IfNil(ba.dataTrieTracker) { return nil, ErrNilTrackableDataTrie } diff --git a/state/dataTrieValue/dataTrieValue.go b/state/dataTrieValue/dataTrieValue.go new file mode 100644 index 00000000000..4f09be3510e --- /dev/null +++ b/state/dataTrieValue/dataTrieValue.go @@ -0,0 +1,2 @@ +//go:generate protoc -I=. -I=$GOPATH/src -I=$GOPATH/src/github.com/ElrondNetwork/protobuf/protobuf --gogoslick_out=. dataTrieValue.proto +package dataTrieValue diff --git a/state/dataTrieValue/dataTrieValue.pb.go b/state/dataTrieValue/dataTrieValue.pb.go new file mode 100644 index 00000000000..e2388e1f75b --- /dev/null +++ b/state/dataTrieValue/dataTrieValue.pb.go @@ -0,0 +1,499 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: dataTrieValue.proto + +package dataTrieValue + +import ( + bytes "bytes" + fmt "fmt" + _ "github.com/gogo/protobuf/gogoproto" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" + reflect "reflect" + strings "strings" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type TrieLeafData struct { + Value []byte `protobuf:"bytes,1,opt,name=Value,proto3" json:"value"` + Key []byte `protobuf:"bytes,2,opt,name=Key,proto3" json:"key"` + Address []byte `protobuf:"bytes,3,opt,name=Address,proto3" json:"address"` +} + +func (m *TrieLeafData) Reset() { *m = TrieLeafData{} } +func (*TrieLeafData) ProtoMessage() {} +func (*TrieLeafData) Descriptor() ([]byte, []int) { + return fileDescriptor_a7eb1d726875d08d, []int{0} +} +func (m *TrieLeafData) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *TrieLeafData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *TrieLeafData) XXX_Merge(src proto.Message) { + xxx_messageInfo_TrieLeafData.Merge(m, src) +} +func (m *TrieLeafData) XXX_Size() int { + return m.Size() +} +func (m *TrieLeafData) XXX_DiscardUnknown() { + xxx_messageInfo_TrieLeafData.DiscardUnknown(m) +} + +var xxx_messageInfo_TrieLeafData proto.InternalMessageInfo + +func (m *TrieLeafData) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *TrieLeafData) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *TrieLeafData) GetAddress() []byte { + if m != nil { + return m.Address + } + return nil +} + +func init() { + proto.RegisterType((*TrieLeafData)(nil), "proto.TrieLeafData") +} + +func init() { proto.RegisterFile("dataTrieValue.proto", fileDescriptor_a7eb1d726875d08d) } + +var fileDescriptor_a7eb1d726875d08d = []byte{ + // 236 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4e, 0x49, 0x2c, 0x49, + 0x0c, 0x29, 0xca, 0x4c, 0x0d, 0x4b, 0xcc, 0x29, 0x4d, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, + 0x62, 0x05, 0x53, 0x52, 0xba, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, + 0xe9, 0xf9, 0xe9, 0xf9, 0xfa, 0x60, 0xe1, 0xa4, 0xd2, 0x34, 0x30, 0x0f, 0xcc, 0x01, 0xb3, 0x20, + 0xba, 0x94, 0x0a, 0xb9, 0x78, 0x40, 0x06, 0xf9, 0xa4, 0x26, 0xa6, 0xb9, 0x24, 0x96, 0x24, 0x0a, + 0xc9, 0x73, 0xb1, 0x82, 0x0d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x71, 0xe2, 0x7c, 0x75, 0x4f, + 0x9e, 0xb5, 0x0c, 0x24, 0x10, 0x04, 0x11, 0x17, 0x92, 0xe4, 0x62, 0xf6, 0x4e, 0xad, 0x94, 0x60, + 0x02, 0x4b, 0xb3, 0xbf, 0xba, 0x27, 0xcf, 0x9c, 0x9d, 0x5a, 0x19, 0x04, 0x12, 0x13, 0x52, 0xe5, + 0x62, 0x77, 0x4c, 0x49, 0x29, 0x4a, 0x2d, 0x2e, 0x96, 0x60, 0x06, 0x4b, 0x73, 0xbf, 0xba, 0x27, + 0xcf, 0x9e, 0x08, 0x11, 0x0a, 0x82, 0xc9, 0x39, 0xb9, 0x5f, 0x78, 0x28, 0xc7, 0x70, 0xe3, 0xa1, + 0x1c, 0xc3, 0x87, 0x87, 0x72, 0x8c, 0x0d, 0x8f, 0xe4, 0x18, 0x57, 0x3c, 0x92, 0x63, 0x3c, 0xf1, + 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x1b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x7c, + 0xf1, 0x48, 0x8e, 0xe1, 0xc3, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, + 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0x17, 0xc5, 0xdf, 0x49, 0x6c, 0x60, 0x2f, 0x18, 0x03, 0x02, + 0x00, 0x00, 0xff, 0xff, 0x56, 0xb1, 0x0b, 0x65, 0x0f, 0x01, 0x00, 0x00, +} + +func (this *TrieLeafData) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*TrieLeafData) + if !ok { + that2, ok := that.(TrieLeafData) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if !bytes.Equal(this.Value, that1.Value) { + return false + } + if !bytes.Equal(this.Key, that1.Key) { + return false + } + if !bytes.Equal(this.Address, that1.Address) { + return false + } + return true +} +func (this *TrieLeafData) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&dataTrieValue.TrieLeafData{") + s = append(s, "Value: "+fmt.Sprintf("%#v", this.Value)+",\n") + s = append(s, "Key: "+fmt.Sprintf("%#v", this.Key)+",\n") + s = append(s, "Address: "+fmt.Sprintf("%#v", this.Address)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringDataTrieValue(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} +func (m *TrieLeafData) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *TrieLeafData) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *TrieLeafData) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Address) > 0 { + i -= len(m.Address) + copy(dAtA[i:], m.Address) + i = encodeVarintDataTrieValue(dAtA, i, uint64(len(m.Address))) + i-- + dAtA[i] = 0x1a + } + if len(m.Key) > 0 { + i -= len(m.Key) + copy(dAtA[i:], m.Key) + i = encodeVarintDataTrieValue(dAtA, i, uint64(len(m.Key))) + i-- + dAtA[i] = 0x12 + } + if len(m.Value) > 0 { + i -= len(m.Value) + copy(dAtA[i:], m.Value) + i = encodeVarintDataTrieValue(dAtA, i, uint64(len(m.Value))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintDataTrieValue(dAtA []byte, offset int, v uint64) int { + offset -= sovDataTrieValue(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *TrieLeafData) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Value) + if l > 0 { + n += 1 + l + sovDataTrieValue(uint64(l)) + } + l = len(m.Key) + if l > 0 { + n += 1 + l + sovDataTrieValue(uint64(l)) + } + l = len(m.Address) + if l > 0 { + n += 1 + l + sovDataTrieValue(uint64(l)) + } + return n +} + +func sovDataTrieValue(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozDataTrieValue(x uint64) (n int) { + return sovDataTrieValue(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (this *TrieLeafData) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&TrieLeafData{`, + `Value:` + fmt.Sprintf("%v", this.Value) + `,`, + `Key:` + fmt.Sprintf("%v", this.Key) + `,`, + `Address:` + fmt.Sprintf("%v", this.Address) + `,`, + `}`, + }, "") + return s +} +func valueToStringDataTrieValue(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} +func (m *TrieLeafData) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TrieLeafData: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TrieLeafData: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthDataTrieValue + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthDataTrieValue + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Value = append(m.Value[:0], dAtA[iNdEx:postIndex]...) + if m.Value == nil { + m.Value = []byte{} + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthDataTrieValue + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthDataTrieValue + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Key = append(m.Key[:0], dAtA[iNdEx:postIndex]...) + if m.Key == nil { + m.Key = []byte{} + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthDataTrieValue + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthDataTrieValue + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Address = append(m.Address[:0], dAtA[iNdEx:postIndex]...) + if m.Address == nil { + m.Address = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipDataTrieValue(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthDataTrieValue + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthDataTrieValue + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipDataTrieValue(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDataTrieValue + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthDataTrieValue + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupDataTrieValue + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthDataTrieValue + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthDataTrieValue = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowDataTrieValue = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupDataTrieValue = fmt.Errorf("proto: unexpected end of group") +) diff --git a/state/dataTrieValue/dataTrieValue.proto b/state/dataTrieValue/dataTrieValue.proto new file mode 100644 index 00000000000..bac376cf2ed --- /dev/null +++ b/state/dataTrieValue/dataTrieValue.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package proto; + +option go_package = "dataTrieValue"; +option (gogoproto.stable_marshaler_all) = true; + +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +message TrieLeafData { + bytes Value = 1 [(gogoproto.jsontag) = "value"]; + bytes Key = 2 [(gogoproto.jsontag) = "key"]; + bytes Address = 3 [(gogoproto.jsontag) = "address"]; +} diff --git a/state/disabled/disabledDataTrieHandler.go b/state/disabled/disabledDataTrieHandler.go new file mode 100644 index 00000000000..10f2d25586b --- /dev/null +++ b/state/disabled/disabledDataTrieHandler.go @@ -0,0 +1,48 @@ +package disabled + +import ( + "context" + + "github.com/multiversx/mx-chain-go/common" +) + +type disabledDataTrieHandler struct { +} + +// NewDisabledDataTrieHandler returns a new instance of disabledDataTrieHandler +func NewDisabledDataTrieHandler() *disabledDataTrieHandler { + return &disabledDataTrieHandler{} +} + +// RootHash returns an empty byte array +func (ddth *disabledDataTrieHandler) RootHash() ([]byte, error) { + return []byte{}, nil +} + +// GetAllLeavesOnChannel does nothing for this implementation +func (ddth *disabledDataTrieHandler) GetAllLeavesOnChannel( + leavesChannels *common.TrieIteratorChannels, + _ context.Context, + _ []byte, + _ common.KeyBuilder, + _ common.TrieLeafParser, +) error { + if leavesChannels.LeavesChan != nil { + close(leavesChannels.LeavesChan) + } + if leavesChannels.ErrChan != nil { + leavesChannels.ErrChan.Close() + } + + return nil +} + +// IsMigratedToLatestVersion returns true +func (ddth *disabledDataTrieHandler) IsMigratedToLatestVersion() (bool, error) { + return true, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (ddth *disabledDataTrieHandler) IsInterfaceNil() bool { + return ddth == nil +} diff --git a/state/disabled/disabledDataTrieHandler_test.go b/state/disabled/disabledDataTrieHandler_test.go new file mode 100644 index 00000000000..be20268f39c --- /dev/null +++ b/state/disabled/disabledDataTrieHandler_test.go @@ -0,0 +1,49 @@ +package disabled + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledDataTrieHandler(t *testing.T) { + t.Parallel() + + t.Run("new disabledDataTrieHandler", func(t *testing.T) { + t.Parallel() + + assert.False(t, check.IfNil(NewDisabledDataTrieHandler())) + }) + + t.Run("root hash", func(t *testing.T) { + t.Parallel() + + ddth := NewDisabledDataTrieHandler() + + rootHash, err := ddth.RootHash() + assert.Equal(t, 0, len(rootHash)) + assert.Nil(t, err) + }) + + t.Run("get all leaves on channel", func(t *testing.T) { + t.Parallel() + + ddth := NewDisabledDataTrieHandler() + + chans := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, 1), + ErrChan: errChan.NewErrChanWrapper(), + } + + err := ddth.GetAllLeavesOnChannel(chans, nil, nil, nil, nil) + assert.Nil(t, err) + _, ok := <-chans.LeavesChan + assert.False(t, ok) + err = chans.ErrChan.ReadFromChanNonBlocking() + assert.Nil(t, err) + }) +} diff --git a/state/disabled/disabledTrackableDataTrie.go b/state/disabled/disabledTrackableDataTrie.go new file mode 100644 index 00000000000..3156862677f --- /dev/null +++ b/state/disabled/disabledTrackableDataTrie.go @@ -0,0 +1,49 @@ +package disabled + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" +) + +type disabledTrackableDataTrie struct { +} + +// NewDisabledTrackableDataTrie returns a new instance of disabledTrackableDataTrie +func NewDisabledTrackableDataTrie() *disabledTrackableDataTrie { + return &disabledTrackableDataTrie{} +} + +// RetrieveValue returns an empty byte array +func (dtdt *disabledTrackableDataTrie) RetrieveValue(_ []byte) ([]byte, uint32, error) { + return []byte{}, 0, nil +} + +// SaveKeyValue does nothing for this implementation +func (dtdt *disabledTrackableDataTrie) SaveKeyValue(_ []byte, _ []byte) error { + return nil +} + +// SetDataTrie does nothing for this implementation +func (dtdt *disabledTrackableDataTrie) SetDataTrie(_ common.Trie) { +} + +// DataTrie returns a new disabledDataTrieHandler +func (dtdt *disabledTrackableDataTrie) DataTrie() common.DataTrieHandler { + return NewDisabledDataTrieHandler() +} + +// SaveDirtyData does nothing for this implementation +func (dtdt *disabledTrackableDataTrie) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { + return make([]core.TrieData, 0), nil +} + +// MigrateDataTrieLeaves does nothing for this implementation +func (dtdt *disabledTrackableDataTrie) MigrateDataTrieLeaves(_ vmcommon.ArgsMigrateDataTrieLeaves) error { + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dtdt *disabledTrackableDataTrie) IsInterfaceNil() bool { + return dtdt == nil +} diff --git a/state/disabled/disabledTrackableDataTrie_test.go b/state/disabled/disabledTrackableDataTrie_test.go new file mode 100644 index 00000000000..7b0c8a11c54 --- /dev/null +++ b/state/disabled/disabledTrackableDataTrie_test.go @@ -0,0 +1,61 @@ +package disabled + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledTrackableDataTrie(t *testing.T) { + t.Parallel() + + assert.False(t, check.IfNil(NewDisabledTrackableDataTrie())) +} + +func TestDisabledTrackableDataTrie_RetrieveValue(t *testing.T) { + t.Parallel() + + dtdt := NewDisabledTrackableDataTrie() + + val, depth, err := dtdt.RetrieveValue(nil) + assert.Nil(t, err) + assert.Equal(t, uint32(0), depth) + assert.Equal(t, 0, len(val)) +} + +func TestDisabledTrackableDataTrie_SaveKeyValue(t *testing.T) { + t.Parallel() + + dtdt := NewDisabledTrackableDataTrie() + + err := dtdt.SaveKeyValue(nil, nil) + assert.Nil(t, err) +} + +func TestDisabledTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { + t.Parallel() + + dtdt := NewDisabledTrackableDataTrie() + isDisabledDataTrieHandler := false + dtdt.SetDataTrie(nil) + tr := dtdt.DataTrie() + + switch tr.(type) { + case *disabledDataTrieHandler: + isDisabledDataTrieHandler = true + default: + assert.Fail(t, "this should not have been called") + } + assert.True(t, isDisabledDataTrieHandler) +} + +func TestDisabledTrackableDataTrie_SaveDirtyData(t *testing.T) { + t.Parallel() + + dtdt := NewDisabledTrackableDataTrie() + + oldValues, err := dtdt.SaveDirtyData(nil) + assert.Nil(t, err) + assert.Equal(t, 0, len(oldValues)) +} diff --git a/state/errors.go b/state/errors.go index 6973f061367..e0a1d94e7dc 100644 --- a/state/errors.go +++ b/state/errors.go @@ -147,3 +147,6 @@ var ErrNilSyncStatisticsHandler = errors.New("nil sync statistics handler") // ErrNilAddressConverter signals that a nil address converter was provided var ErrNilAddressConverter = errors.New("nil address converter") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/state/export_test.go b/state/export_test.go index 3ff10d977b2..beb447ba355 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -1,6 +1,7 @@ package state import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -77,3 +78,28 @@ func (accountsDB *accountsDBApi) SetCurrentBlockInfo(blockInfo common.BlockInfo) func EmptyErrChanReturningHadContained(errChan chan error) bool { return emptyErrChanReturningHadContained(errChan) } + +// DirtyData - +type DirtyData struct { + Value []byte + NewVersion core.TrieNodeVersion +} + +// DirtyData - +func (tdaw *trackableDataTrie) DirtyData() map[string]DirtyData { + dd := make(map[string]DirtyData, len(tdaw.dirtyData)) + + for key, value := range tdaw.dirtyData { + dd[key] = DirtyData{ + Value: value.value, + NewVersion: value.newVersion, + } + } + + return dd +} + +// SaveDirtyData - +func (a *userAccount) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { + return a.dataTrieTracker.SaveDirtyData(trie) +} diff --git a/state/factory/accountCreator.go b/state/factory/accountCreator.go index 86d20821b6e..66ff4056ab9 100644 --- a/state/factory/accountCreator.go +++ b/state/factory/accountCreator.go @@ -1,22 +1,37 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // AccountCreator has method to create a new account type AccountCreator struct { + accountArgs state.ArgsAccountCreation } -// NewAccountCreator creates an account creator -func NewAccountCreator() state.AccountFactory { - return &AccountCreator{} +// NewAccountCreator creates a new instance of AccountCreator +func NewAccountCreator(args state.ArgsAccountCreation) (state.AccountFactory, error) { + if check.IfNil(args.Hasher) { + return nil, errors.ErrNilHasher + } + if check.IfNil(args.Marshaller) { + return nil, errors.ErrNilMarshalizer + } + if check.IfNil(args.EnableEpochsHandler) { + return nil, errors.ErrNilEnableEpochsHandler + } + + return &AccountCreator{ + accountArgs: args, + }, nil } // CreateAccount calls the new Account creator and returns the result func (ac *AccountCreator) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return state.NewUserAccount(address, ac.accountArgs) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/factory/accountCreator_test.go b/state/factory/accountCreator_test.go index 1750716235c..e6dd28bea84 100644 --- a/state/factory/accountCreator_test.go +++ b/state/factory/accountCreator_test.go @@ -4,15 +4,66 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/factory" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" ) +func getDefaultArgs() state.ArgsAccountCreation { + return state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } +} + +func TestNewAccountCreator(t *testing.T) { + t.Parallel() + + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.Hasher = nil + accF, err := factory.NewAccountCreator(args) + assert.True(t, check.IfNil(accF)) + assert.Equal(t, errors.ErrNilHasher, err) + }) + t.Run("nil marshalizer", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.Marshaller = nil + accF, err := factory.NewAccountCreator(args) + assert.True(t, check.IfNil(accF)) + assert.Equal(t, errors.ErrNilMarshalizer, err) + }) + t.Run("nil enableEpochsHandler", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.EnableEpochsHandler = nil + accF, err := factory.NewAccountCreator(args) + assert.True(t, check.IfNil(accF)) + assert.Equal(t, errors.ErrNilEnableEpochsHandler, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + accF, err := factory.NewAccountCreator(getDefaultArgs()) + assert.False(t, check.IfNil(accF)) + assert.Nil(t, err) + }) +} + func TestAccountCreator_CreateAccountNilAddress(t *testing.T) { t.Parallel() - accF := factory.NewAccountCreator() + accF, _ := factory.NewAccountCreator(getDefaultArgs()) _, ok := accF.(*factory.AccountCreator) assert.Equal(t, true, ok) @@ -27,7 +78,7 @@ func TestAccountCreator_CreateAccountNilAddress(t *testing.T) { func TestAccountCreator_CreateAccountOk(t *testing.T) { t.Parallel() - accF := factory.NewAccountCreator() + accF, _ := factory.NewAccountCreator(getDefaultArgs()) _, ok := accF.(*factory.AccountCreator) assert.Equal(t, true, ok) diff --git a/state/factory/accountsAdapterAPICreator_test.go b/state/factory/accountsAdapterAPICreator_test.go index b0151b907c6..c6c579985c1 100644 --- a/state/factory/accountsAdapterAPICreator_test.go +++ b/state/factory/accountsAdapterAPICreator_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" mockState "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" @@ -23,7 +24,7 @@ func createMockAccountsArgs() state.ArgsAccountsDB { }, }, Hasher: &testscommon.HasherStub{}, - Marshaller: &testscommon.MarshalizerMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, AccountFactory: &mockState.AccountsFactoryStub{}, StoragePruningManager: &mockState.StoragePruningManagerStub{}, ProcessingMode: 0, diff --git a/state/interface.go b/state/interface.go index 8071418796c..0445ead10e1 100644 --- a/state/interface.go +++ b/state/interface.go @@ -4,6 +4,7 @@ import ( "context" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -86,6 +87,7 @@ type UserAccountHandler interface { SetUserName(userName []byte) GetUserName() []byte IsGuarded() bool + GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context) error vmcommon.AccountHandler } @@ -95,7 +97,8 @@ type DataTrieTracker interface { SaveKeyValue(key []byte, value []byte) error SetDataTrie(tr common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(common.Trie) (map[string][]byte, error) + SaveDirtyData(common.Trie) ([]core.TrieData, error) + MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error IsInterfaceNil() bool } @@ -120,7 +123,7 @@ type AccountsAdapter interface { SnapshotState(rootHash []byte) SetStateCheckpoint(rootHash []byte) IsPruningEnabled() bool - GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error + GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error RecreateAllTries(rootHash []byte) (map[string]common.Trie, error) GetTrie(rootHash []byte) (common.Trie, error) GetStackDebugFirstEntry() []byte @@ -165,7 +168,7 @@ type baseAccountHandler interface { GetRootHash() []byte SetDataTrie(trie common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(trie common.Trie) (map[string][]byte, error) + SaveDirtyData(trie common.Trie) ([]core.TrieData, error) IsInterfaceNil() bool } @@ -218,3 +221,10 @@ type AccountsAdapterAPI interface { GetAccountWithBlockInfo(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) GetCodeWithBlockInfo(codeHash []byte, options common.RootHashHolder) ([]byte, common.BlockInfo, error) } + +type dataTrie interface { + common.Trie + + UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error + CollectLeavesForMigration(args vmcommon.ArgsMigrateDataTrieLeaves) error +} diff --git a/state/journalEntries.go b/state/journalEntries.go index 28e7fade9f6..a7f66fec8f3 100644 --- a/state/journalEntries.go +++ b/state/journalEntries.go @@ -4,9 +4,9 @@ import ( "bytes" "fmt" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -167,12 +167,12 @@ func (jea *journalEntryAccountCreation) IsInterfaceNil() bool { // JournalEntryDataTrieUpdates stores all the updates done to the account's data trie, // so it can be reverted in case of rollback type journalEntryDataTrieUpdates struct { - trieUpdates map[string][]byte + trieUpdates []core.TrieData account baseAccountHandler } // NewJournalEntryDataTrieUpdates outputs a new JournalEntryDataTrieUpdates implementation used to revert an account's data trie -func NewJournalEntryDataTrieUpdates(trieUpdates map[string][]byte, account baseAccountHandler) (*journalEntryDataTrieUpdates, error) { +func NewJournalEntryDataTrieUpdates(trieUpdates []core.TrieData, account baseAccountHandler) (*journalEntryDataTrieUpdates, error) { if check.IfNil(account) { return nil, fmt.Errorf("%w in NewJournalEntryDataTrieUpdates", ErrNilAccountHandler) } @@ -188,18 +188,22 @@ func NewJournalEntryDataTrieUpdates(trieUpdates map[string][]byte, account baseA // Revert applies undo operation func (jedtu *journalEntryDataTrieUpdates) Revert() (vmcommon.AccountHandler, error) { - trie, ok := jedtu.account.DataTrie().(common.Trie) + trie, ok := jedtu.account.DataTrie().(dataTrie) if !ok { return nil, fmt.Errorf("invalid trie, type is %T", jedtu.account.DataTrie()) } - for key := range jedtu.trieUpdates { - err := trie.Update([]byte(key), jedtu.trieUpdates[key]) + for _, trieUpdate := range jedtu.trieUpdates { + err := trie.UpdateWithVersion(trieUpdate.Key, trieUpdate.Value, trieUpdate.Version) if err != nil { return nil, err } - log.Trace("revert data trie update", "key", []byte(key), "val", jedtu.trieUpdates[key]) + log.Trace("revert data trie update", + "key", trieUpdate.Key, + "val", trieUpdate.Value, + "version", trieUpdate.Version, + ) } rootHash, err := trie.RootHash() diff --git a/state/journalEntries_test.go b/state/journalEntries_test.go index ae19ece1ad5..86eaa51d256 100644 --- a/state/journalEntries_test.go +++ b/state/journalEntries_test.go @@ -4,9 +4,12 @@ import ( "errors" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/stretchr/testify/assert" @@ -15,7 +18,7 @@ import ( func TestNewJournalEntryCode_NilUpdaterShouldErr(t *testing.T) { t.Parallel() - entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), nil, &testscommon.MarshalizerMock{}) + entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), nil, &marshallerMock.MarshalizerMock{}) assert.True(t, check.IfNil(entry)) assert.Equal(t, state.ErrNilUpdater, err) } @@ -31,7 +34,7 @@ func TestNewJournalEntryCode_NilMarshalizerShouldErr(t *testing.T) { func TestNewJournalEntryCode_OkParams(t *testing.T) { t.Parallel() - entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), &trieMock.TrieStub{}, &testscommon.MarshalizerMock{}) + entry, err := state.NewJournalEntryCode(&state.CodeEntry{}, []byte("code hash"), []byte("code hash"), &trieMock.TrieStub{}, &marshallerMock.MarshalizerMock{}) assert.Nil(t, err) assert.False(t, check.IfNil(entry)) } @@ -40,7 +43,7 @@ func TestJournalEntryCode_OldHashAndNewHashAreNil(t *testing.T) { t.Parallel() trieStub := &trieMock.TrieStub{} - entry, _ := state.NewJournalEntryCode(&state.CodeEntry{}, nil, nil, trieStub, &testscommon.MarshalizerMock{}) + entry, _ := state.NewJournalEntryCode(&state.CodeEntry{}, nil, nil, trieStub, &marshallerMock.MarshalizerMock{}) acc, err := entry.Revert() assert.Nil(t, err) @@ -54,7 +57,7 @@ func TestJournalEntryCode_OldHashIsNilAndNewHashIsNotNil(t *testing.T) { Code: []byte("newCode"), NumReferences: 1, } - marshalizer := &testscommon.MarshalizerMock{} + marshalizer := &marshallerMock.MarshalizerMock{} updateCalled := false trieStub := &trieMock.TrieStub{ @@ -173,8 +176,12 @@ func TestJournalEntryAccountCreation_RevertUpdatesTheTrie(t *testing.T) { func TestNewJournalEntryDataTrieUpdates_NilAccountShouldErr(t *testing.T) { t.Parallel() - trieUpdates := make(map[string][]byte) - trieUpdates["a"] = []byte("b") + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ + Key: []byte("a"), + Value: []byte("b"), + Version: 0, + }) entry, err := state.NewJournalEntryDataTrieUpdates(trieUpdates, nil) assert.True(t, check.IfNil(entry)) @@ -184,8 +191,13 @@ func TestNewJournalEntryDataTrieUpdates_NilAccountShouldErr(t *testing.T) { func TestNewJournalEntryDataTrieUpdates_EmptyTrieUpdatesShouldErr(t *testing.T) { t.Parallel() - trieUpdates := make(map[string][]byte) - accnt, _ := state.NewUserAccount(make([]byte, 32)) + trieUpdates := make([]core.TrieData, 0) + args := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accnt, _ := state.NewUserAccount(make([]byte, 32), args) entry, err := state.NewJournalEntryDataTrieUpdates(trieUpdates, accnt) assert.True(t, check.IfNil(entry)) @@ -195,9 +207,18 @@ func TestNewJournalEntryDataTrieUpdates_EmptyTrieUpdatesShouldErr(t *testing.T) func TestNewJournalEntryDataTrieUpdates_OkValsShouldWork(t *testing.T) { t.Parallel() - trieUpdates := make(map[string][]byte) - trieUpdates["a"] = []byte("b") - accnt, _ := state.NewUserAccount(make([]byte, 32)) + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ + Key: []byte("a"), + Value: []byte("b"), + Version: 0, + }) + args := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accnt, _ := state.NewUserAccount(make([]byte, 32), args) entry, err := state.NewJournalEntryDataTrieUpdates(trieUpdates, accnt) assert.Nil(t, err) @@ -209,12 +230,16 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenUpdateFails(t *testing.T) { expectedErr := errors.New("error") - trieUpdates := make(map[string][]byte) - trieUpdates["a"] = []byte("b") + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ + Key: []byte("a"), + Value: []byte("b"), + Version: 0, + }) accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return expectedErr }, } @@ -232,12 +257,16 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenAccountRootFails(t *testing. expectedErr := errors.New("error") - trieUpdates := make(map[string][]byte) - trieUpdates["a"] = []byte("b") + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ + Key: []byte("a"), + Value: []byte("b"), + Version: 0, + }) accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, RootCalled: func() ([]byte, error) { @@ -259,12 +288,16 @@ func TestJournalEntryDataTrieUpdates_RevertShouldWork(t *testing.T) { updateWasCalled := false rootWasCalled := false - trieUpdates := make(map[string][]byte) - trieUpdates["a"] = []byte("b") + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ + Key: []byte("a"), + Value: []byte("b"), + Version: 0, + }) accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { updateWasCalled = true return nil }, diff --git a/state/parsers/dataTrieLeafParser.go b/state/parsers/dataTrieLeafParser.go new file mode 100644 index 00000000000..6437fbb55b9 --- /dev/null +++ b/state/parsers/dataTrieLeafParser.go @@ -0,0 +1,59 @@ +package parsers + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" +) + +type dataTrieLeafParser struct { + address []byte + marshaller marshal.Marshalizer + enableEpochsHandler common.EnableEpochsHandler +} + +// NewDataTrieLeafParser returns a new instance of dataTrieLeafParser +func NewDataTrieLeafParser(address []byte, marshaller marshal.Marshalizer, enableEpochsHandler common.EnableEpochsHandler) (*dataTrieLeafParser, error) { + if check.IfNil(marshaller) { + return nil, errors.ErrNilMarshalizer + } + if check.IfNil(enableEpochsHandler) { + return nil, errors.ErrNilEnableEpochsHandler + } + + return &dataTrieLeafParser{ + address: address, + marshaller: marshaller, + enableEpochsHandler: enableEpochsHandler, + }, nil +} + +// ParseLeaf returns a new KeyValStorage with the actual key and value +func (tlp *dataTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) { + if tlp.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() && version == core.AutoBalanceEnabled { + data := &dataTrieValue.TrieLeafData{} + err := tlp.marshaller.Unmarshal(data, trieVal) + if err != nil { + return nil, err + } + + return keyValStorage.NewKeyValStorage(data.Key, data.Value), nil + } + + suffix := append(trieKey, tlp.address...) + value, err := common.TrimSuffixFromValue(trieVal, len(suffix)) + if err != nil { + return nil, err + } + + return keyValStorage.NewKeyValStorage(trieKey, value), nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (tlp *dataTrieLeafParser) IsInterfaceNil() bool { + return tlp == nil +} diff --git a/state/parsers/dataTrieLeafParser_test.go b/state/parsers/dataTrieLeafParser_test.go new file mode 100644 index 00000000000..ba18aa0e6c0 --- /dev/null +++ b/state/parsers/dataTrieLeafParser_test.go @@ -0,0 +1,130 @@ +package parsers + +import ( + "encoding/hex" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/stretchr/testify/assert" +) + +func TestNewDataTrieLeafParser(t *testing.T) { + t.Parallel() + + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + tlp, err := NewDataTrieLeafParser([]byte("address"), nil, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.True(t, check.IfNil(tlp)) + assert.Equal(t, errors.ErrNilMarshalizer, err) + }) + + t.Run("nil enableEpochsHandler", func(t *testing.T) { + t.Parallel() + + tlp, err := NewDataTrieLeafParser([]byte("address"), &marshallerMock.MarshalizerMock{}, nil) + assert.True(t, check.IfNil(tlp)) + assert.Equal(t, errors.ErrNilEnableEpochsHandler, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + tlp, err := NewDataTrieLeafParser([]byte("address"), &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.Nil(t, err) + assert.False(t, check.IfNil(tlp)) + }) +} + +func TestTrieLeafParser_ParseLeaf(t *testing.T) { + t.Parallel() + + t.Run("auto balance disabled", func(t *testing.T) { + t.Parallel() + + key := []byte("key") + val := []byte("val") + address := []byte("address") + suffix := append(key, address...) + tlp, _ := NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), core.NotSpecified) + assert.Nil(t, err) + assert.Equal(t, key, keyVal.Key()) + assert.Equal(t, val, keyVal.Value()) + }) + + t.Run("auto balance enabled - val with appended data", func(t *testing.T) { + t.Parallel() + + key := []byte("key") + val := []byte("val") + address := []byte("address") + suffix := append(key, address...) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tlp, _ := NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, enableEpochsHandler) + + keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), core.NotSpecified) + assert.Nil(t, err) + assert.Equal(t, key, keyVal.Key()) + assert.Equal(t, val, keyVal.Value()) + }) + + t.Run("auto balance enabled - val as struct", func(t *testing.T) { + t.Parallel() + + marshaller := &marshallerMock.MarshalizerMock{} + hasher := &hashingMocks.HasherMock{} + key := []byte("key") + val := []byte("val") + address := []byte("address") + leafData := dataTrieValue.TrieLeafData{ + Value: val, + Key: key, + Address: address, + } + serializedLeafData, _ := marshaller.Marshal(leafData) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tlp, _ := NewDataTrieLeafParser(address, marshaller, enableEpochsHandler) + + keyVal, err := tlp.ParseLeaf(hasher.Compute(string(key)), serializedLeafData, core.AutoBalanceEnabled) + assert.Nil(t, err) + assert.Equal(t, key, keyVal.Key()) + assert.Equal(t, val, keyVal.Value()) + }) + + t.Run("unmarshall bytes with appended data should not return empty data", func(t *testing.T) { + t.Parallel() + + marshaller := &marshal.GogoProtoMarshalizer{} + + keyBytes := []byte("eth") + valBytes := []byte("0xA2AA67319062488CAFfc7E52802a3308cAF78a54") + addrBytes, err := hex.DecodeString("b080fe7e47edd5f32b619a7a439a0174ebda49ac27a5b112dd685470ae008001") + assert.Nil(t, err) + + valWithAppendedData := append(valBytes, keyBytes...) + valWithAppendedData = append(valWithAppendedData, addrBytes...) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tlp, _ := NewDataTrieLeafParser(addrBytes, marshaller, enableEpochsHandler) + + keyVal, err := tlp.ParseLeaf(keyBytes, valWithAppendedData, core.NotSpecified) + assert.Nil(t, err) + assert.Equal(t, keyBytes, keyVal.Key()) + assert.Equal(t, valBytes, keyVal.Value()) + }) +} diff --git a/state/parsers/mainTrieLeafParser.go b/state/parsers/mainTrieLeafParser.go new file mode 100644 index 00000000000..8835608fd7c --- /dev/null +++ b/state/parsers/mainTrieLeafParser.go @@ -0,0 +1,24 @@ +package parsers + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" +) + +type mainTrieLeafParser struct { +} + +// NewMainTrieLeafParser creates a new instance of mainTrieLeafParser +func NewMainTrieLeafParser() *mainTrieLeafParser { + return &mainTrieLeafParser{} +} + +// ParseLeaf returns the given key an value as a KeyValStorage +func (tlp *mainTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, _ core.TrieNodeVersion) (core.KeyValueHolder, error) { + return keyValStorage.NewKeyValStorage(trieKey, trieVal), nil +} + +// IsInterfaceNil returns nil if there is no value under the interface +func (tlp *mainTrieLeafParser) IsInterfaceNil() bool { + return tlp == nil +} diff --git a/state/parsers/mainTrieLeafParser_test.go b/state/parsers/mainTrieLeafParser_test.go new file mode 100644 index 00000000000..fc94dcc8ae6 --- /dev/null +++ b/state/parsers/mainTrieLeafParser_test.go @@ -0,0 +1,32 @@ +package parsers + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewMainTrieLeafParser(t *testing.T) { + t.Parallel() + + t.Run("new mainTrieLeafParser", func(t *testing.T) { + t.Parallel() + + assert.False(t, check.IfNil(NewMainTrieLeafParser())) + }) + + t.Run("parse leaf", func(t *testing.T) { + t.Parallel() + + key := []byte("key") + value := []byte("value") + dtlp := NewMainTrieLeafParser() + + keyValHolder, err := dtlp.ParseLeaf(key, value, core.NotSpecified) + assert.Nil(t, err) + assert.Equal(t, key, keyValHolder.Key()) + assert.Equal(t, value, keyValHolder.Value()) + }) +} diff --git a/state/peerAccount.go b/state/peerAccount.go index edc835199ee..9aa30583e9f 100644 --- a/state/peerAccount.go +++ b/state/peerAccount.go @@ -5,10 +5,12 @@ import ( "math/big" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/state/disabled" ) // PeerAccount is the struct used in serialization/deserialization type peerAccount struct { + //TODO investigate if *baseAccount is needed in peerAccount, and remove if not *baseAccount PeerAccountData } @@ -24,7 +26,7 @@ func NewEmptyPeerAccount() *peerAccount { } } -// NewPeerAccount creates new simple account wrapper for an PeerAccountContainer (that has just been initialized) +// NewPeerAccount creates a new instance of peerAccount func NewPeerAccount(address []byte) (*peerAccount, error) { if len(address) == 0 { return nil, ErrNilAddress @@ -33,7 +35,7 @@ func NewPeerAccount(address []byte) (*peerAccount, error) { return &peerAccount{ baseAccount: &baseAccount{ address: address, - dataTrieTracker: NewTrackableDataTrie(address, nil), + dataTrieTracker: disabled.NewDisabledTrackableDataTrie(), }, PeerAccountData: PeerAccountData{ AccumulatedFees: big.NewInt(0), diff --git a/state/snapshotStatistics.go b/state/snapshotStatistics.go index 366e9de5b8c..812a99ed263 100644 --- a/state/snapshotStatistics.go +++ b/state/snapshotStatistics.go @@ -48,11 +48,11 @@ func (ss *snapshotStatistics) WaitForSnapshotsToFinish() { } // AddTrieStats adds the given trie stats to the snapshot statistics -func (ss *snapshotStatistics) AddTrieStats(trieStats *statistics.TrieStatsDTO) { +func (ss *snapshotStatistics) AddTrieStats(trieStats common.TrieStatisticsHandler, trieType common.TrieType) { ss.mutex.Lock() defer ss.mutex.Unlock() - ss.trieStatisticsCollector.Add(trieStats) + ss.trieStatisticsCollector.Add(trieStats, trieType) } // WaitForSyncToFinish will wait until the waitGroup counter is zero @@ -88,3 +88,8 @@ func (ss *snapshotStatistics) PrintStats(identifier string, rootHash []byte) { func (ss *snapshotStatistics) GetSnapshotNumNodes() uint64 { return ss.trieStatisticsCollector.GetNumNodes() } + +// IsInterfaceNil returns true if there is no value under the interface +func (ss *snapshotStatistics) IsInterfaceNil() bool { + return ss == nil +} diff --git a/state/snapshotStatistics_test.go b/state/snapshotStatistics_test.go index 140cbc10856..4a739064571 100644 --- a/state/snapshotStatistics_test.go +++ b/state/snapshotStatistics_test.go @@ -21,7 +21,7 @@ func TestSnapshotStatistics_Concurrency(t *testing.T) { for i := 0; i < numRuns; i++ { ss.NewSnapshotStarted() go func() { - ss.AddTrieStats(getTrieStatsDTO(5, 60).GetTrieStats()) + ss.AddTrieStats(getTrieStatsDTO(5, 60), common.DataTrie) ss.SnapshotFinished() }() } diff --git a/state/storagePruningManager/storagePruningManager_test.go b/state/storagePruningManager/storagePruningManager_test.go index 1a1a8ace76e..4db053a9d27 100644 --- a/state/storagePruningManager/storagePruningManager_test.go +++ b/state/storagePruningManager/storagePruningManager_test.go @@ -9,7 +9,9 @@ import ( "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" @@ -23,24 +25,30 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. SnapshotsBufferLen: 10, SnapshotsGoroutineNum: 1, } - marshaller := &testscommon.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} args := storage.GetStorageManagerArgs() args.CheckpointHashesHolder = hashesHolder.NewCheckpointHashesHolder(10000000, testscommon.HashSize) trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, 5) + tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := NewStoragePruningManager(ewl, generalCfg.PruningBufferLen) + argsAccCreator := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) argsAccountsDB := state.ArgsAccountsDB{ Trie: tr, Hasher: hasher, Marshaller: marshaller, - AccountFactory: factory.NewAccountCreator(), + AccountFactory: accCreator, StoragePruningManager: spm, ProcessingMode: common.Normal, ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, diff --git a/state/syncer/baseAccountsSyncer.go b/state/syncer/baseAccountsSyncer.go index f31575562bb..a01f1155fed 100644 --- a/state/syncer/baseAccountsSyncer.go +++ b/state/syncer/baseAccountsSyncer.go @@ -33,6 +33,7 @@ type baseAccountsSyncer struct { checkNodesOnDisk bool userAccountsSyncStatisticsHandler common.SizeSyncStatisticsHandler appStatusHandler core.AppStatusHandler + enableEpochsHandler common.EnableEpochsHandler trieSyncerVersion int numTriesSynced int32 @@ -51,6 +52,7 @@ type ArgsNewBaseAccountsSyncer struct { Cacher storage.Cacher UserAccountsSyncStatisticsHandler common.SizeSyncStatisticsHandler AppStatusHandler core.AppStatusHandler + EnableEpochsHandler common.EnableEpochsHandler MaxTrieLevelInMemory uint MaxHardCapForMissingNodes int TrieSyncerVersion int @@ -79,6 +81,9 @@ func checkArgs(args ArgsNewBaseAccountsSyncer) error { if check.IfNil(args.AppStatusHandler) { return state.ErrNilAppStatusHandler } + if check.IfNil(args.EnableEpochsHandler) { + return state.ErrNilEnableEpochsHandler + } if args.MaxHardCapForMissingNodes < 1 { return state.ErrInvalidMaxHardCapForMissingNodes } @@ -210,7 +215,7 @@ func (b *baseAccountsSyncer) GetSyncedTries() map[string]common.Trie { b.mutex.Lock() defer b.mutex.Unlock() - dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.enableEpochsHandler, b.maxTrieLevelInMemory) if err != nil { log.Warn("error creating a new trie in baseAccountsSyncer.GetSyncedTries", "error", err) return make(map[string]common.Trie) diff --git a/state/syncer/baseAccoutnsSyncer_test.go b/state/syncer/baseAccoutnsSyncer_test.go index de71219d74b..da3819b05ce 100644 --- a/state/syncer/baseAccoutnsSyncer_test.go +++ b/state/syncer/baseAccoutnsSyncer_test.go @@ -7,7 +7,9 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/stretchr/testify/require" @@ -16,13 +18,14 @@ import ( func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { return syncer.ArgsNewBaseAccountsSyncer{ Hasher: &hashingMocks.HasherMock{}, - Marshalizer: testscommon.MarshalizerMock{}, + Marshalizer: marshallerMock.MarshalizerMock{}, TrieStorageManager: &storageManager.StorageManagerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, Timeout: time.Second, Cacher: testscommon.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, MaxTrieLevelInMemory: 5, MaxHardCapForMissingNodes: 100, TrieSyncerVersion: 3, diff --git a/state/syncer/errors.go b/state/syncer/errors.go index 5a12356ecf8..7bfea7e1673 100644 --- a/state/syncer/errors.go +++ b/state/syncer/errors.go @@ -7,3 +7,6 @@ var ErrNilPubkeyConverter = errors.New("nil pubkey converter") // ErrNilStorageMarker signals that a nil storage marker was provided var ErrNilStorageMarker = errors.New("nil storage marker") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler was provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/state/syncer/userAccountSyncer_test.go b/state/syncer/userAccountSyncer_test.go new file mode 100644 index 00000000000..eefdd96778f --- /dev/null +++ b/state/syncer/userAccountSyncer_test.go @@ -0,0 +1,113 @@ +package syncer + +import ( + "testing" + "time" + + "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" + "github.com/multiversx/mx-chain-go/testscommon/storageManager" + "github.com/multiversx/mx-chain-go/trie" + "github.com/stretchr/testify/assert" +) + +// TODO add more tests + +func getDefaultBaseAccSyncerArgs() ArgsNewBaseAccountsSyncer { + return ArgsNewBaseAccountsSyncer{ + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: marshallerMock.MarshalizerMock{}, + TrieStorageManager: &storageManager.StorageManagerStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + Timeout: time.Second, + Cacher: testscommon.NewCacherMock(), + UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, + AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + MaxTrieLevelInMemory: 0, + MaxHardCapForMissingNodes: 100, + TrieSyncerVersion: 2, + CheckNodesOnDisk: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } +} + +func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { + t.Parallel() + + args := ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + ShardId: 0, + Throttler: &mock.ThrottlerStub{}, + AddressPubKeyConverter: &testscommon.PubkeyConverterStub{}, + } + syncer, err := NewUserAccountsSyncer(args) + assert.Nil(t, err) + assert.NotNil(t, syncer) + + err = syncer.SyncAccounts([]byte("rootHash"), nil) + assert.Equal(t, ErrNilStorageMarker, err) +} + +func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { + t.Parallel() + + numNodesSynced := 0 + numProcessedCalled := 0 + setNumMissingCalled := 0 + args := ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + ShardId: 0, + Throttler: &mock.ThrottlerStub{}, + AddressPubKeyConverter: &testscommon.PubkeyConverterStub{}, + } + args.TrieStorageManager = &storageManager.StorageManagerStub{ + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + numNodesSynced++ + return nil + }, + } + args.UserAccountsSyncStatisticsHandler = &testscommon.SizeSyncStatisticsHandlerStub{ + AddNumProcessedCalled: func(value int) { + numProcessedCalled++ + }, + SetNumMissingCalled: func(rootHash []byte, value int) { + setNumMissingCalled++ + assert.Equal(t, 0, value) + }, + } + + var serializedLeafNode []byte + tsm := &storageManager.StorageManagerStub{ + PutCalled: func(key []byte, val []byte) error { + serializedLeafNode = val + return nil + }, + } + + tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + key := []byte("key") + value := []byte("value") + _ = tr.Update(key, value) + rootHash, _ := tr.RootHash() + _ = tr.Commit() + + args.Cacher = &testscommon.CacherStub{ + GetCalled: func(key []byte) (value interface{}, ok bool) { + interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + return interceptedNode, true + }, + } + + syncer, _ := NewUserAccountsSyncer(args) + // test that timeout watchdog is reset + time.Sleep(args.Timeout * 2) + syncer.MissingDataTrieNodeFound(rootHash) + + assert.Equal(t, 1, numNodesSynced) + assert.Equal(t, 1, numProcessedCalled) + assert.Equal(t, 1, setNumMissingCalled) +} diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go index f503849f943..283e3c25b3e 100644 --- a/state/syncer/userAccountsSyncer.go +++ b/state/syncer/userAccountsSyncer.go @@ -66,6 +66,9 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer, if check.IfNil(args.AddressPubKeyConverter) { return nil, ErrNilPubkeyConverter } + if check.IfNil(args.EnableEpochsHandler) { + return nil, ErrNilEnableEpochsHandler + } timeoutHandler, err := common.NewTimeoutHandler(args.Timeout) if err != nil { @@ -88,6 +91,7 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer, checkNodesOnDisk: args.CheckNodesOnDisk, userAccountsSyncStatisticsHandler: args.UserAccountsSyncStatisticsHandler, appStatusHandler: args.AppStatusHandler, + enableEpochsHandler: args.EnableEpochsHandler, } u := &userAccountsSyncer{ @@ -242,12 +246,15 @@ func (u *userAccountsSyncer) syncAccountDataTries( defer u.printDataTrieStatistics() wg := sync.WaitGroup{} - + argsAccCreation := state.ArgsAccountCreation{ + Hasher: u.hasher, + Marshaller: u.marshalizer, + EnableEpochsHandler: u.enableEpochsHandler, + } for leaf := range leavesChannels.LeavesChan { u.resetTimeoutHandlerWatchdog() - account := state.NewEmptyUserAccount() - err := u.marshalizer.Unmarshal(account, leaf.Value()) + account, err := state.NewUserAccountFromBytes(leaf.Value(), argsAccCreation) if err != nil { log.Trace("this must be a leaf with code", "leaf key", leaf.Key(), "err", err) continue diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go index 8f1ca462be3..f6036c110c0 100644 --- a/state/syncer/userAccountsSyncer_test.go +++ b/state/syncer/userAccountsSyncer_test.go @@ -15,8 +15,12 @@ import ( "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/hashesHolder" @@ -35,6 +39,14 @@ func getDefaultUserAccountsSyncerArgs() syncer.ArgsNewUserAccountsSyncer { } } +func getDefaultArgsAccountCreation() state.ArgsAccountCreation { + return state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } +} + func TestNewUserAccountsSyncer(t *testing.T) { t.Parallel() @@ -105,7 +117,7 @@ func getSerializedTrieNode( }, } - tr, _ := trie.NewTrie(tsm, marshaller, hasher, 5) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) _ = tr.Update(key, []byte("value")) _ = tr.Commit() @@ -156,7 +168,7 @@ func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { }) } -func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, uint) { +func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, uint) { marshalizer := &testscommon.ProtobufMarshalizerMock{} hasher := &testscommon.KeccakMock{} @@ -180,7 +192,7 @@ func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, has trieStorageManager, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(1) - return trieStorageManager, args.Marshalizer, args.Hasher, maxTrieLevelInMemory + return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory } func emptyTrie() common.Trie { @@ -232,10 +244,10 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { s, err := syncer.NewUserAccountsSyncer(args) require.Nil(t, err) - _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, 5) + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) tr := emptyTrie() - account, err := state.NewUserAccount(testscommon.TestPubKeyAlice) + account, err := state.NewUserAccount(testscommon.TestPubKeyAlice, getDefaultArgsAccountCreation()) require.Nil(t, err) account.SetRootHash(key) @@ -255,7 +267,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { rootHash, err := tr.RootHash() require.Nil(t, err) - err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) require.Nil(t, err) ctx, cancel := context.WithCancel(context.TODO()) @@ -289,10 +301,10 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { s, err := syncer.NewUserAccountsSyncer(args) require.Nil(t, err) - _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, 5) + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) tr := emptyTrie() - account, err := state.NewUserAccount(testscommon.TestPubKeyAlice) + account, err := state.NewUserAccount(testscommon.TestPubKeyAlice, getDefaultArgsAccountCreation()) require.Nil(t, err) account.SetRootHash(key) @@ -312,7 +324,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { rootHash, err := tr.RootHash() require.Nil(t, err) - err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder()) + err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) require.Nil(t, err) err = s.SyncAccountDataTries(leavesChannels, context.TODO()) @@ -356,7 +368,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { }, } - tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, 5) + tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) key := []byte("key") value := []byte("value") _ = tr.Update(key, value) diff --git a/state/syncer/validatorAccountsSyncer.go b/state/syncer/validatorAccountsSyncer.go index db70df18930..943368441d4 100644 --- a/state/syncer/validatorAccountsSyncer.go +++ b/state/syncer/validatorAccountsSyncer.go @@ -50,6 +50,7 @@ func NewValidatorAccountsSyncer(args ArgsNewValidatorAccountsSyncer) (*validator checkNodesOnDisk: args.CheckNodesOnDisk, userAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), appStatusHandler: args.AppStatusHandler, + enableEpochsHandler: args.EnableEpochsHandler, } u := &validatorAccountsSyncer{ diff --git a/state/trackableDataTrie.go b/state/trackableDataTrie.go index 09c0672c742..3f30fc06654 100644 --- a/state/trackableDataTrie.go +++ b/state/trackableDataTrie.go @@ -1,81 +1,187 @@ package state import ( + "fmt" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + errorsCommon "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) +type dirtyData struct { + value []byte + newVersion core.TrieNodeVersion +} + // TrackableDataTrie wraps a PatriciaMerkelTrie adding modifying data capabilities type trackableDataTrie struct { - dirtyData map[string][]byte - tr common.Trie - identifier []byte + dirtyData map[string]dirtyData + tr common.Trie + hasher hashing.Hasher + marshaller marshal.Marshalizer + enableEpochsHandler common.EnableEpochsHandler + identifier []byte } // NewTrackableDataTrie returns an instance of trackableDataTrie -func NewTrackableDataTrie(identifier []byte, tr common.Trie) *trackableDataTrie { - return &trackableDataTrie{ - tr: tr, - dirtyData: make(map[string][]byte), - identifier: identifier, +func NewTrackableDataTrie( + identifier []byte, + tr common.Trie, + hasher hashing.Hasher, + marshaller marshal.Marshalizer, + enableEpochsHandler common.EnableEpochsHandler, +) (*trackableDataTrie, error) { + if check.IfNil(hasher) { + return nil, ErrNilHasher } + if check.IfNil(marshaller) { + return nil, ErrNilMarshalizer + } + if check.IfNil(enableEpochsHandler) { + return nil, ErrNilEnableEpochsHandler + } + + return &trackableDataTrie{ + tr: tr, + hasher: hasher, + marshaller: marshaller, + dirtyData: make(map[string]dirtyData), + identifier: identifier, + enableEpochsHandler: enableEpochsHandler, + }, nil } // RetrieveValue fetches the value from a particular key searching the account data store // The search starts with dirty map, continues with original map and ends with the trie // Data must have been retrieved from its trie func (tdaw *trackableDataTrie) RetrieveValue(key []byte) ([]byte, uint32, error) { - tailLength := len(key) + len(tdaw.identifier) - // search in dirty data cache - if value, found := tdaw.dirtyData[string(key)]; found { - log.Trace("retrieve value from dirty data", "key", key, "value", value) - trimmedVal, err := trimValue(value, tailLength) - return trimmedVal, 0, err + if dataEntry, found := tdaw.dirtyData[string(key)]; found { + log.Trace("retrieve value from dirty data", "key", key, "value", dataEntry.value, "account", tdaw.identifier) + return dataEntry.value, 0, nil } // ok, not in cache, retrieve from trie - if tdaw.tr == nil { + if check.IfNil(tdaw.tr) { return nil, 0, ErrNilTrie } - value, depth, err := tdaw.tr.Get(key) + trieValue, depth, err := tdaw.retrieveValueFromTrie(key) if err != nil { return nil, depth, err } - log.Trace("retrieve value from trie", "key", key, "value", value, "depth", depth) - value, _ = trimValue(value, tailLength) - - return value, depth, nil -} -func trimValue(value []byte, tailLength int) ([]byte, error) { - dataLength := len(value) - tailLength - if dataLength < 0 { - return nil, ErrNegativeValue + val, err := tdaw.getValueWithoutMetadata(key, trieValue) + if err != nil { + return nil, depth, err } - return value[:dataLength], nil + log.Trace("retrieve value from trie", "key", key, "value", val, "account", tdaw.identifier) + + return val, depth, nil } // SaveKeyValue stores in dirtyData the data keys "touched" // It does not care if the data is really dirty as calling this check here will be sub-optimal func (tdaw *trackableDataTrie) SaveKeyValue(key []byte, value []byte) error { - var identifier []byte - lenValue := uint64(len(value)) - if lenValue > core.MaxLeafSize { + if uint64(len(value)) > core.MaxLeafSize { return data.ErrLeafSizeTooBig } - if lenValue != 0 { - identifier = append(key, tdaw.identifier...) + dataEntry := dirtyData{ + value: value, + newVersion: core.GetVersionForNewData(tdaw.enableEpochsHandler), } - tdaw.dirtyData[string(key)] = append(value, identifier...) + tdaw.dirtyData[string(key)] = dataEntry return nil } +// MigrateDataTrieLeaves migrates the data trie leaves from oldVersion to newVersion +func (tdaw *trackableDataTrie) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error { + if check.IfNil(tdaw.tr) { + return ErrNilTrie + } + if check.IfNil(args.TrieMigrator) { + return errorsCommon.ErrNilTrieMigrator + } + + dtr, ok := tdaw.tr.(dataTrie) + if !ok { + return fmt.Errorf("invalid trie, type is %T", tdaw.tr) + } + + err := dtr.CollectLeavesForMigration(args) + if err != nil { + return err + } + + dataToBeMigrated := args.TrieMigrator.GetLeavesToBeMigrated() + for _, leafData := range dataToBeMigrated { + dataEntry := dirtyData{ + value: leafData.Value, + newVersion: args.NewVersion, + } + + originalKey, err := tdaw.getOriginalKeyFromTrieData(leafData) + if err != nil { + return err + } + + tdaw.dirtyData[string(originalKey)] = dataEntry + } + + return nil +} + +func (tdaw *trackableDataTrie) getOriginalKeyFromTrieData(trieData core.TrieData) ([]byte, error) { + if trieData.Version == core.AutoBalanceEnabled { + valWithMetadata := &dataTrieValue.TrieLeafData{} + err := tdaw.marshaller.Unmarshal(valWithMetadata, trieData.Value) + if err != nil { + return nil, err + } + + return valWithMetadata.Key, nil + } + + return trieData.Key, nil +} + +func (tdaw *trackableDataTrie) getKeyForVersion(key []byte, version core.TrieNodeVersion) []byte { + if version == core.AutoBalanceEnabled { + return tdaw.hasher.Compute(string(key)) + } + + return key +} + +func (tdaw *trackableDataTrie) getValueForVersion(key []byte, value []byte, version core.TrieNodeVersion) ([]byte, error) { + if len(value) == 0 { + return nil, nil + } + + if version == core.AutoBalanceEnabled { + trieVal := &dataTrieValue.TrieLeafData{ + Value: value, + Key: key, + Address: tdaw.identifier, + } + + return tdaw.marshaller.Marshal(trieVal) + } + + identifier := append(key, tdaw.identifier...) + valueWithAppendedData := append(value, identifier...) + + return valueWithAppendedData, nil +} + // SetDataTrie sets the internal data trie func (tdaw *trackableDataTrie) SetDataTrie(tr common.Trie) { tdaw.tr = tr @@ -87,9 +193,9 @@ func (tdaw *trackableDataTrie) DataTrie() common.DataTrieHandler { } // SaveDirtyData saved the dirty data to the trie -func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) (map[string][]byte, error) { +func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { if len(tdaw.dirtyData) == 0 { - return map[string][]byte{}, nil + return make([]core.TrieData, 0), nil } if check.IfNil(tdaw.tr) { @@ -101,26 +207,160 @@ func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) (map[string][ tdaw.tr = newDataTrie } - oldValues := make(map[string][]byte) + dtr, ok := tdaw.tr.(dataTrie) + if !ok { + return nil, fmt.Errorf("invalid trie, type is %T", tdaw.tr) + } - for k, v := range tdaw.dirtyData { - val, _, err := tdaw.tr.Get([]byte(k)) + return tdaw.updateTrie(dtr) +} + +func (tdaw *trackableDataTrie) updateTrie(dtr dataTrie) ([]core.TrieData, error) { + oldValues := make([]core.TrieData, len(tdaw.dirtyData)) + + index := 0 + for key, dataEntry := range tdaw.dirtyData { + oldVal, _, err := tdaw.retrieveValueFromTrie([]byte(key)) if err != nil { - return oldValues, err + return nil, err } + oldValues[index] = oldVal - oldValues[k] = val + err = tdaw.deleteOldEntryIfMigrated([]byte(key), dataEntry, oldVal) + if err != nil { + return nil, err + } - err = tdaw.tr.Update([]byte(k), v) + err = tdaw.modifyTrie([]byte(key), dataEntry, oldVal, dtr) if err != nil { - return oldValues, err + return nil, err } + + index++ } - tdaw.dirtyData = make(map[string][]byte) + tdaw.dirtyData = make(map[string]dirtyData) + return oldValues, nil } +func (tdaw *trackableDataTrie) retrieveValueFromTrie(key []byte) (core.TrieData, uint32, error) { + if tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { + hashedKey := tdaw.hasher.Compute(string(key)) + valWithMetadata, depth, err := tdaw.tr.Get(hashedKey) + if err != nil { + return core.TrieData{}, 0, err + } + if len(valWithMetadata) != 0 { + trieValue := core.TrieData{ + Key: hashedKey, + Value: valWithMetadata, + Version: core.AutoBalanceEnabled, + } + + return trieValue, depth, nil + } + } + + valWithMetadata, depth, err := tdaw.tr.Get(key) + if err != nil { + return core.TrieData{}, 0, err + } + if len(valWithMetadata) != 0 { + trieValue := core.TrieData{ + Key: key, + Value: valWithMetadata, + Version: core.NotSpecified, + } + + return trieValue, depth, nil + } + + newDataVersion := core.GetVersionForNewData(tdaw.enableEpochsHandler) + keyForTrie := tdaw.getKeyForVersion(key, newDataVersion) + + trieValue := core.TrieData{ + Key: keyForTrie, + Value: nil, + Version: newDataVersion, + } + + return trieValue, depth, nil +} + +func (tdaw *trackableDataTrie) getValueWithoutMetadata(key []byte, trieData core.TrieData) ([]byte, error) { + if len(trieData.Value) == 0 { + return nil, nil + } + + if trieData.Version == core.AutoBalanceEnabled { + return tdaw.getValueAutoBalanceVersion(trieData.Value) + } + + return tdaw.getValueNotSpecifiedVersion(key, trieData.Value) +} + +func (tdaw *trackableDataTrie) getValueAutoBalanceVersion(val []byte) ([]byte, error) { + dataTrieVal := &dataTrieValue.TrieLeafData{} + err := tdaw.marshaller.Unmarshal(dataTrieVal, val) + if err != nil { + return nil, err + } + + return dataTrieVal.Value, nil +} + +func (tdaw *trackableDataTrie) getValueNotSpecifiedVersion(key []byte, val []byte) ([]byte, error) { + tailLength := len(key) + len(tdaw.identifier) + trimmedValue, _ := common.TrimSuffixFromValue(val, tailLength) + + return trimmedValue, nil +} + +func (tdaw *trackableDataTrie) deleteOldEntryIfMigrated(key []byte, newData dirtyData, oldEntry core.TrieData) error { + if !tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { + return nil + } + + isMigration := oldEntry.Version == core.NotSpecified && newData.newVersion == core.AutoBalanceEnabled + if isMigration && len(newData.value) != 0 { + return tdaw.tr.Delete(key) + } + + return nil +} + +func (tdaw *trackableDataTrie) modifyTrie(key []byte, dataEntry dirtyData, oldVal core.TrieData, dtr dataTrie) error { + if len(dataEntry.value) == 0 { + return tdaw.deleteFromTrie(oldVal, key, dtr) + } + + version := dataEntry.newVersion + newKey := tdaw.getKeyForVersion(key, version) + value, err := tdaw.getValueForVersion(key, dataEntry.value, version) + if err != nil { + return err + } + + return dtr.UpdateWithVersion(newKey, value, version) +} + +func (tdaw *trackableDataTrie) deleteFromTrie(oldVal core.TrieData, key []byte, dtr dataTrie) error { + if len(oldVal.Value) == 0 { + return nil + } + + if oldVal.Version == core.AutoBalanceEnabled { + return dtr.Delete(tdaw.hasher.Compute(string(key))) + } + + if oldVal.Version == core.NotSpecified { + return dtr.Delete(key) + } + + return nil +} + // IsInterfaceNil returns true if there is no value under the interface func (tdaw *trackableDataTrie) IsInterfaceNil() bool { return tdaw == nil diff --git a/state/trackableDataTrie_test.go b/state/trackableDataTrie_test.go index 33bedfc0266..a3ed92d535a 100644 --- a/state/trackableDataTrie_test.go +++ b/state/trackableDataTrie_test.go @@ -7,8 +7,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" + errorsCommon "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/dataTrieValue" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -16,152 +23,846 @@ import ( func TestNewTrackableDataTrie(t *testing.T) { t.Parallel() - identifier := []byte("identifier") - trie := &trieMock.TrieStub{} - tdaw := state.NewTrackableDataTrie(identifier, trie) + t.Run("create with nil hasher", func(t *testing.T) { + t.Parallel() - assert.False(t, check.IfNil(tdaw)) -} + tdt, err := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, nil, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.Equal(t, state.ErrNilHasher, err) + assert.True(t, check.IfNil(tdt)) + }) -func TestTrackableDataTrie_RetrieveValueNilDataTrieShouldErr(t *testing.T) { - t.Parallel() + t.Run("create with nil marshaller", func(t *testing.T) { + t.Parallel() - as := state.NewTrackableDataTrie([]byte("identifier"), nil) - assert.NotNil(t, as) + tdt, err := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, nil, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.Equal(t, state.ErrNilMarshalizer, err) + assert.True(t, check.IfNil(tdt)) + }) - _, trieDepth, err := as.RetrieveValue([]byte("ABC")) - assert.NotNil(t, err) - assert.Equal(t, uint32(0), trieDepth) -} + t.Run("create with nil enableEpochsHandler", func(t *testing.T) { + t.Parallel() -func TestTrackableDataTrie_RetrieveValueFoundInTrieShouldWork(t *testing.T) { - t.Parallel() + tdt, err := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, nil) + assert.Equal(t, state.ErrNilEnableEpochsHandler, err) + assert.True(t, check.IfNil(tdt)) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() - identifier := []byte("identifier") - expectedKey := []byte("key") - - expectedVal := []byte("value") - value := append(expectedVal, expectedKey...) - value = append(value, identifier...) - expectedTrieDepth := uint32(5) - - trie := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { - return nil - }, - GetCalled: func(key []byte) ([]byte, uint32, error) { - if bytes.Equal(key, expectedKey) { - return value, expectedTrieDepth, nil - } - return nil, 0, nil - }, - } - mdaw := state.NewTrackableDataTrie(identifier, trie) - assert.NotNil(t, mdaw) - - valRecovered, trieDepth, err := mdaw.RetrieveValue(expectedKey) - assert.Nil(t, err) - assert.Equal(t, expectedVal, valRecovered) - assert.Equal(t, expectedTrieDepth, trieDepth) + tdt, err := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.Nil(t, err) + assert.False(t, check.IfNil(tdt)) + }) } -func TestTrackableDataTrie_RetrieveValueMalfunctionTrieShouldErr(t *testing.T) { +func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { t.Parallel() - errExpected := errors.New("expected err") - keyExpected := []byte("key") - trie := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { - return nil - }, - GetCalled: func(_ []byte) ([]byte, uint32, error) { - return nil, 0, errExpected - }, - } - mdaw := state.NewTrackableDataTrie([]byte("identifier"), trie) - assert.NotNil(t, mdaw) - - valRecovered, _, err := mdaw.RetrieveValue(keyExpected) - assert.Equal(t, errExpected, err) - assert.Nil(t, valRecovered) + t.Run("data too large", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + err := tdt.SaveKeyValue([]byte("key"), make([]byte, core.MaxLeafSize+1)) + assert.Equal(t, err, data.ErrLeafSizeTooBig) + }) + + t.Run("should save given val only in dirty data", func(t *testing.T) { + t.Parallel() + + keyExpected := []byte("key") + value := []byte("value") + trie := &trieMock.TrieStub{ + UpdateCalled: func(key, value []byte) error { + assert.Fail(t, "should not have saved directly in the trie") + return nil + }, + GetCalled: func(key []byte) ([]byte, uint32, error) { + assert.Fail(t, "should not have saved directly in the trie") + return nil, 0, nil + }, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.NotNil(t, tdt) + + _ = tdt.SaveKeyValue(keyExpected, value) + + dirtyData := tdt.DirtyData() + assert.Equal(t, 1, len(dirtyData)) + assert.Equal(t, value, dirtyData[string(keyExpected)].Value) + }) } -func TestTrackableDataTrie_RetrieveValueShouldCheckDirtyDataFirst(t *testing.T) { +func TestTrackableDataTrie_RetrieveValue(t *testing.T) { t.Parallel() - identifier := []byte("id") - key := []byte("key") - tail := append(key, identifier...) - retrievedTrieVal := []byte("value") - trieValue := append(retrievedTrieVal, tail...) - newTrieValue := []byte("new trie value") - expectedTrieDepth := uint32(5) - - trie := &trieMock.TrieStub{ - GetCalled: func(_ []byte) ([]byte, uint32, error) { - return trieValue, expectedTrieDepth, nil - }, - } - mdaw := state.NewTrackableDataTrie([]byte("id"), trie) - assert.NotNil(t, mdaw) - - valRecovered, trieDepth, err := mdaw.RetrieveValue(key) - assert.Equal(t, retrievedTrieVal, valRecovered) - assert.Nil(t, err) - assert.Equal(t, expectedTrieDepth, trieDepth) - - _ = mdaw.SaveKeyValue(key, newTrieValue) - valRecovered, trieDepth, err = mdaw.RetrieveValue(key) - assert.Equal(t, newTrieValue, valRecovered) - assert.Nil(t, err) - assert.Equal(t, uint32(0), trieDepth) + t.Run("should check dirty data first", func(t *testing.T) { + t.Parallel() + + identifier := []byte("id") + key := []byte("key") + tail := append(key, identifier...) + retrievedTrieVal := []byte("value") + trieValue := append(retrievedTrieVal, tail...) + newTrieValue := []byte("new trie value") + + trie := &trieMock.TrieStub{ + GetCalled: func(trieKey []byte) ([]byte, uint32, error) { + if bytes.Equal(trieKey, key) { + return trieValue, 0, nil + } + return nil, 0, nil + }, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(key) + assert.Equal(t, retrievedTrieVal, valRecovered) + assert.Nil(t, err) + + _ = tdt.SaveKeyValue(key, newTrieValue) + valRecovered, _, err = tdt.RetrieveValue(key) + assert.Equal(t, newTrieValue, valRecovered) + assert.Nil(t, err) + }) + + t.Run("nil data trie should err", func(t *testing.T) { + t.Parallel() + + tdt, err := state.NewTrackableDataTrie([]byte("identifier"), nil, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.Nil(t, err) + assert.NotNil(t, tdt) + + _, _, err = tdt.RetrieveValue([]byte("ABC")) + assert.Equal(t, state.ErrNilTrie, err) + }) + + t.Run("val with appended data found in trie", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + expectedVal := []byte("value") + value := append(expectedVal, expectedKey...) + value = append(value, identifier...) + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, expectedKey) { + return value, 0, nil + } + return nil, 0, nil + }, + } + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpochsHandler) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(expectedKey) + assert.Nil(t, err) + assert.Equal(t, expectedVal, valRecovered) + }) + + t.Run("autoBalance data tries disabled", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + expectedVal := []byte("value") + value := append(expectedVal, expectedKey...) + value = append(value, identifier...) + hasher := &hashingMocks.HasherMock{} + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, expectedKey) { + return value, 0, nil + } + if bytes.Equal(key, hasher.Compute(string(expectedKey))) { + assert.Fail(t, "this should not have been called") + } + return nil, 0, nil + }, + } + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpochsHandler) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(expectedKey) + assert.Nil(t, err) + assert.Equal(t, expectedVal, valRecovered) + }) + + t.Run("val as struct found in trie", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + expectedVal := []byte("value") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + + trie := &trieMock.TrieStub{ + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, hasher.Compute(string(expectedKey))) { + serializedVal, _ := marshaller.Marshal(&dataTrieValue.TrieLeafData{ + Value: expectedVal, + Key: expectedKey, + Address: identifier, + }) + return serializedVal, 0, nil + } + return nil, 0, nil + }, + } + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, hasher, marshaller, enableEpochsHandler) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(expectedKey) + assert.Nil(t, err) + assert.Equal(t, expectedVal, valRecovered) + }) + + t.Run("trie malfunction should err", func(t *testing.T) { + t.Parallel() + + errExpected := errors.New("expected err") + keyExpected := []byte("key") + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, errExpected + }, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(keyExpected) + assert.Equal(t, errExpected, err) + assert.Nil(t, valRecovered) + }) + + t.Run("val not found in trie - auto balance enabled", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + + trie := &trieMock.TrieStub{ + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + } + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie( + identifier, + trie, + hasher, + marshaller, + enableEpochsHandler, + ) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(expectedKey) + assert.Nil(t, err) + assert.Equal(t, []byte(nil), valRecovered) + }) + + t.Run("val not found in trie - auto balance disabled", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + + trie := &trieMock.TrieStub{ + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + } + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := state.NewTrackableDataTrie( + identifier, + trie, + hasher, + marshaller, + enableEpochsHandler, + ) + assert.NotNil(t, tdt) + + valRecovered, _, err := tdt.RetrieveValue(expectedKey) + assert.Nil(t, err) + assert.Equal(t, []byte(nil), valRecovered) + }) } -func TestTrackableDataTrie_SaveKeyValueShouldSaveOnlyInDirty(t *testing.T) { +func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Parallel() - identifier := []byte("identifier") - keyExpected := []byte("key") - value := []byte("value") - - trie := &trieMock.TrieStub{ - UpdateCalled: func(key, value []byte) error { - return nil - }, - GetCalled: func(_ []byte) ([]byte, uint32, error) { - assert.Fail(t, "should not have saved directly in the trie") - return nil, 0, nil - }, - } - mdaw := state.NewTrackableDataTrie(identifier, trie) - assert.NotNil(t, mdaw) - - _ = mdaw.SaveKeyValue(keyExpected, value) - - // test in dirty - retrievedVal, _, err := mdaw.RetrieveValue(keyExpected) - assert.Nil(t, err) - assert.Equal(t, value, retrievedVal) + t.Run("no dirty data", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), nil, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + oldValues, err := tdt.SaveDirtyData(&trieMock.TrieStub{}) + assert.Nil(t, err) + assert.Equal(t, 0, len(oldValues)) + }) + + t.Run("nil trie creates a new trie", func(t *testing.T) { + t.Parallel() + + recreateCalled := false + trie := &trieMock.TrieStub{ + RecreateCalled: func(root []byte) (common.Trie, error) { + recreateCalled = true + return &trieMock.TrieStub{ + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + UpdateWithVersionCalled: func(_, _ []byte, _ core.TrieNodeVersion) error { + return nil + }, + }, nil + }, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), nil, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + key := []byte("key") + _ = tdt.SaveKeyValue(key, []byte("val")) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, key, oldValues[0].Key) + assert.Equal(t, []byte(nil), oldValues[0].Value) + assert.True(t, recreateCalled) + }) + + t.Run("present in trie as valWithAppendedData", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + expectedVal := []byte("value") + value := append(expectedVal, expectedKey...) + value = append(value, identifier...) + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + deleteCalled := false + updateCalled := false + + trieVal := &dataTrieValue.TrieLeafData{ + Value: expectedVal, + Key: expectedKey, + Address: identifier, + } + serializedTrieVal, _ := marshaller.Marshal(trieVal) + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, expectedKey) { + return value, 0, nil + } + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, hasher.Compute(string(expectedKey)), key) + assert.Equal(t, serializedTrieVal, value) + updateCalled = true + return nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + deleteCalled = true + return nil + }, + } + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, hasher, marshaller, enableEpochsHandler) + + _ = tdt.SaveKeyValue(expectedKey, expectedVal) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, expectedKey, oldValues[0].Key) + assert.Equal(t, value, oldValues[0].Value) + assert.True(t, deleteCalled) + assert.True(t, updateCalled) + }) + + t.Run("present in trie as valWithAppendedData and auto balancing disabled", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + val := []byte("value") + expectedVal := append(val, expectedKey...) + expectedVal = append(expectedVal, identifier...) + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + updateCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, expectedKey) { + return expectedVal, 0, nil + } + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, expectedVal, value) + updateCalled = true + return nil + }, + DeleteCalled: func(key []byte) error { + assert.Fail(t, "this should not have been called") + return nil + }, + } + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, hasher, marshaller, enableEpochsHandler) + + _ = tdt.SaveKeyValue(expectedKey, val) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, expectedKey, oldValues[0].Key) + assert.Equal(t, expectedVal, oldValues[0].Value) + assert.True(t, updateCalled) + }) + + t.Run("present in trie as valAsStruct", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + newVal := []byte("value") + oldVal := []byte("old val") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + updateCalled := false + + oldTrieVal := &dataTrieValue.TrieLeafData{ + Value: oldVal, + Key: expectedKey, + Address: identifier, + } + serializedOldTrieVal, _ := marshaller.Marshal(oldTrieVal) + + newTrieVal := &dataTrieValue.TrieLeafData{ + Value: newVal, + Key: expectedKey, + Address: identifier, + } + serializedNewTrieVal, _ := marshaller.Marshal(newTrieVal) + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(key, hasher.Compute(string(expectedKey))) { + return serializedOldTrieVal, 0, nil + } + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, hasher.Compute(string(expectedKey)), key) + assert.Equal(t, serializedNewTrieVal, value) + updateCalled = true + return nil + }, + DeleteCalled: func(key []byte) error { + assert.Fail(t, "this delete should not have been called") + return nil + }, + } + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, hasher, marshaller, enableEpochsHandler) + + _ = tdt.SaveKeyValue(expectedKey, newVal) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, hasher.Compute(string(expectedKey)), oldValues[0].Key) + assert.Equal(t, serializedOldTrieVal, oldValues[0].Value) + assert.True(t, updateCalled) + }) + + t.Run("not present in trie", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + newVal := []byte("value") + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + updateCalled := false + + newTrieVal := &dataTrieValue.TrieLeafData{ + Value: newVal, + Key: expectedKey, + Address: identifier, + } + serializedNewTrieVal, _ := marshaller.Marshal(newTrieVal) + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, hasher.Compute(string(expectedKey)), key) + assert.Equal(t, serializedNewTrieVal, value) + updateCalled = true + return nil + }, + DeleteCalled: func(key []byte) error { + assert.Fail(t, "this delete should not have been called") + return nil + }, + } + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie(identifier, trie, hasher, marshaller, enableEpochsHandler) + + _ = tdt.SaveKeyValue(expectedKey, newVal) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, hasher.Compute(string(expectedKey)), oldValues[0].Key) + assert.Equal(t, []byte(nil), oldValues[0].Value) + assert.True(t, updateCalled) + }) + + t.Run("dirty data is reset", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + val := []byte("value") + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + return nil + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + _ = tdt.SaveKeyValue(expectedKey, val) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + }) + + t.Run("nil val autobalance disabled", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + updateCalled := false + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return []byte("value"), 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + updateCalled = true + return nil + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.True(t, updateCalled) + }) + + t.Run("nil val and nil old val", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + deleteCalled := false + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + deleteCalled = true + return nil + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.False(t, deleteCalled) + }) + + t.Run("nil val autobalance enabled, old val saved at hashedKey", func(t *testing.T) { + t.Parallel() + + hasher := &hashingMocks.HasherMock{} + expectedKey := []byte("key") + deleteCalled := false + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(hasher.Compute(string(expectedKey)), key) { + return []byte("value"), 0, nil + } + + return nil, 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, hasher.Compute(string(expectedKey)), key) + deleteCalled = true + return nil + }, + } + + enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.True(t, deleteCalled) + }) + + t.Run("nil val autobalance enabled, old val saved at key", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + deleteCalled := 0 + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(expectedKey, key) { + return []byte("value"), 0, nil + } + + return nil, 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + deleteCalled++ + return nil + }, + } + + enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.Equal(t, 1, deleteCalled) + }) + + t.Run("not present in trie - autobalance disabled", func(t *testing.T) { + t.Parallel() + + identifier := []byte("identifier") + expectedKey := []byte("key") + newVal := []byte("value") + valueWithMetadata := append(newVal, expectedKey...) + valueWithMetadata = append(valueWithMetadata, identifier...) + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + updateCalled := false + + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, valueWithMetadata, value) + updateCalled = true + return nil + }, + DeleteCalled: func(key []byte) error { + assert.Fail(t, "this delete should not have been called") + return nil + }, + } + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tdt, _ := state.NewTrackableDataTrie( + identifier, + trie, + hasher, + marshaller, + enableEpochsHandler, + ) + + _ = tdt.SaveKeyValue(expectedKey, newVal) + oldValues, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 1, len(oldValues)) + assert.Equal(t, expectedKey, oldValues[0].Key) + assert.Equal(t, []byte(nil), oldValues[0].Value) + assert.True(t, updateCalled) + }) } -func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { +func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { t.Parallel() - trie := &trieMock.TrieStub{} - mdaw := state.NewTrackableDataTrie([]byte("identifier"), trie) - - newTrie := &trieMock.TrieStub{} - mdaw.SetDataTrie(newTrie) - assert.Equal(t, newTrie, mdaw.DataTrie()) + t.Run("nil trie", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), nil, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: &trieMock.DataTrieMigratorStub{}, + } + err := tdt.MigrateDataTrieLeaves(args) + assert.Equal(t, state.ErrNilTrie, err) + }) + + t.Run("nil trie migrator", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: nil, + } + err := tdt.MigrateDataTrieLeaves(args) + assert.Equal(t, errorsCommon.ErrNilTrieMigrator, err) + }) + + t.Run("CollectLeavesForMigrationFails", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + tr := &trieMock.TrieStub{ + CollectLeavesForMigrationCalled: func(_ vmcommon.ArgsMigrateDataTrieLeaves) error { + return expectedErr + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), tr, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: &trieMock.DataTrieMigratorStub{}, + } + err := tdt.MigrateDataTrieLeaves(args) + assert.Equal(t, expectedErr, err) + }) + + t.Run("leaves that need to be migrated are added to dirty data", func(t *testing.T) { + t.Parallel() + + leavesToBeMigrated := []core.TrieData{ + { + Key: []byte("key1"), + Value: []byte("value1"), + Version: core.NotSpecified, + }, + { + Key: []byte("key2"), + Value: []byte("value2"), + Version: core.NotSpecified, + }, + { + Key: []byte("key3"), + Value: []byte("value3"), + Version: core.NotSpecified, + }, + } + tr := &trieMock.TrieStub{ + CollectLeavesForMigrationCalled: func(_ vmcommon.ArgsMigrateDataTrieLeaves) error { + return nil + }, + } + dtm := &trieMock.DataTrieMigratorStub{ + GetLeavesToBeMigratedCalled: func() []core.TrieData { + return leavesToBeMigrated + }, + } + enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), tr, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs) + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: 100, + TrieMigrator: dtm, + } + err := tdt.MigrateDataTrieLeaves(args) + assert.Nil(t, err) + + dirtyData := tdt.DirtyData() + assert.Equal(t, len(leavesToBeMigrated), len(dirtyData)) + for i := range leavesToBeMigrated { + d := dirtyData[string(leavesToBeMigrated[i].Key)] + assert.Equal(t, leavesToBeMigrated[i].Value, d.Value) + assert.Equal(t, core.TrieNodeVersion(100), d.NewVersion) + } + }) } -func TestTrackableDataTrie_SaveKeyValueTooBig(t *testing.T) { +func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { t.Parallel() - identifier := []byte("identifier") trie := &trieMock.TrieStub{} - tdaw := state.NewTrackableDataTrie(identifier, trie) + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) - err := tdaw.SaveKeyValue([]byte("key"), make([]byte, core.MaxLeafSize+1)) - assert.Equal(t, err, data.ErrLeafSizeTooBig) + newTrie := &trieMock.TrieStub{} + tdt.SetDataTrie(newTrie) + assert.Equal(t, newTrie, tdt.DataTrie()) } diff --git a/state/userAccount.go b/state/userAccount.go index 72ff86cd306..0cfbec13a6d 100644 --- a/state/userAccount.go +++ b/state/userAccount.go @@ -3,8 +3,15 @@ package state import ( "bytes" + "context" "math/big" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/state/parsers" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -14,40 +21,97 @@ var _ UserAccountHandler = (*userAccount)(nil) type userAccount struct { *baseAccount UserAccountData + marshaller marshal.Marshalizer + enableEpochsHandler common.EnableEpochsHandler } var zero = big.NewInt(0) -// NewEmptyUserAccount creates new simple account wrapper for an AccountContainer (that has just been initialized) -func NewEmptyUserAccount() *userAccount { - return &userAccount{ - baseAccount: &baseAccount{}, - UserAccountData: UserAccountData{ - DeveloperReward: big.NewInt(0), - Balance: big.NewInt(0), - }, - } +// ArgsAccountCreation holds the arguments needed to create a new instance of userAccount +type ArgsAccountCreation struct { + Hasher hashing.Hasher + Marshaller marshal.Marshalizer + EnableEpochsHandler common.EnableEpochsHandler } -// NewUserAccount creates new simple account wrapper for an AccountContainer (that has just been initialized) -func NewUserAccount(address []byte) (*userAccount, error) { +// NewUserAccount creates a new instance of userAccount +func NewUserAccount( + address []byte, + args ArgsAccountCreation, +) (*userAccount, error) { if len(address) == 0 { return nil, ErrNilAddress } + err := checkArgs(args) + if err != nil { + return nil, err + } + + tdt, err := NewTrackableDataTrie(address, nil, args.Hasher, args.Marshaller, args.EnableEpochsHandler) + if err != nil { + return nil, err + } return &userAccount{ baseAccount: &baseAccount{ address: address, - dataTrieTracker: NewTrackableDataTrie(address, nil), + dataTrieTracker: tdt, }, UserAccountData: UserAccountData{ DeveloperReward: big.NewInt(0), Balance: big.NewInt(0), Address: address, }, + marshaller: args.Marshaller, + enableEpochsHandler: args.EnableEpochsHandler, }, nil } +// NewUserAccountFromBytes creates a new instance of userAccount from the given bytes +func NewUserAccountFromBytes( + accountBytes []byte, + args ArgsAccountCreation, +) (*userAccount, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + acc := &userAccount{} + err = args.Marshaller.Unmarshal(acc, accountBytes) + if err != nil { + return nil, err + } + + tdt, err := NewTrackableDataTrie(acc.Address, nil, args.Hasher, args.Marshaller, args.EnableEpochsHandler) + if err != nil { + return nil, err + } + + acc.baseAccount = &baseAccount{ + address: acc.Address, + dataTrieTracker: tdt, + } + acc.marshaller = args.Marshaller + acc.enableEpochsHandler = args.EnableEpochsHandler + + return acc, nil +} + +func checkArgs(args ArgsAccountCreation) error { + if check.IfNil(args.Marshaller) { + return ErrNilMarshalizer + } + if check.IfNil(args.Hasher) { + return ErrNilHasher + } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + + return nil +} + // SetUserName sets the users name func (a *userAccount) SetUserName(userName []byte) { a.UserName = make([]byte, 0, len(userName)) @@ -149,6 +213,39 @@ func (a *userAccount) IsGuarded() bool { return codeMetaData.Guarded } +// GetAllLeaves returns all the leaves of the account's data trie +func (a *userAccount) GetAllLeaves( + leavesChannels *common.TrieIteratorChannels, + ctx context.Context, +) error { + dataTrie := a.dataTrieTracker.DataTrie() + if check.IfNil(dataTrie) { + return ErrNilTrie + } + + rootHash, err := dataTrie.RootHash() + if err != nil { + return err + } + + tlp, err := parsers.NewDataTrieLeafParser(a.Address, a.marshaller, a.enableEpochsHandler) + if err != nil { + return err + } + + return dataTrie.GetAllLeavesOnChannel(leavesChannels, ctx, rootHash, keyBuilder.NewKeyBuilder(), tlp) +} + +// IsDataTrieMigrated returns true if the data trie is migrated to the latest version +func (a *userAccount) IsDataTrieMigrated() (bool, error) { + dt := a.dataTrieTracker.DataTrie() + if check.IfNil(dt) { + return false, ErrNilTrie + } + + return dt.IsMigratedToLatestVersion() +} + // IsInterfaceNil returns true if there is no value under the interface func (a *userAccount) IsInterfaceNil() bool { return a == nil diff --git a/state/userAccount_test.go b/state/userAccount_test.go index cb686d08f4e..b079dea3a56 100644 --- a/state/userAccount_test.go +++ b/state/userAccount_test.go @@ -1,34 +1,76 @@ package state_test import ( + "context" + "fmt" "math/big" + "strconv" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/stretchr/testify/assert" ) -func TestNewUserAccount_NilAddressContainerShouldErr(t *testing.T) { +func TestNewUserAccount(t *testing.T) { t.Parallel() - acc, err := state.NewUserAccount(nil) - assert.True(t, check.IfNil(acc)) - assert.Equal(t, state.ErrNilAddress, err) -} - -func TestNewUserAccount_OkParamsShouldWork(t *testing.T) { - t.Parallel() - - acc, err := state.NewUserAccount(make([]byte, 32)) - assert.Nil(t, err) - assert.False(t, check.IfNil(acc)) + t.Run("nil address", func(t *testing.T) { + t.Parallel() + + acc, err := state.NewUserAccount(nil, getDefaultArgsAccountCreation()) + assert.True(t, check.IfNil(acc)) + assert.Equal(t, state.ErrNilAddress, err) + }) + + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgsAccountCreation() + args.Hasher = nil + acc, err := state.NewUserAccount(make([]byte, 32), args) + assert.True(t, check.IfNil(acc)) + assert.Equal(t, state.ErrNilHasher, err) + }) + + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgsAccountCreation() + args.Marshaller = nil + acc, err := state.NewUserAccount(make([]byte, 32), args) + assert.True(t, check.IfNil(acc)) + assert.Equal(t, state.ErrNilMarshalizer, err) + }) + + t.Run("nil enableEpochsHandler", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgsAccountCreation() + args.EnableEpochsHandler = nil + acc, err := state.NewUserAccount(make([]byte, 32), args) + assert.True(t, check.IfNil(acc)) + assert.Equal(t, state.ErrNilEnableEpochsHandler, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + acc, err := state.NewUserAccount(make([]byte, 32), getDefaultArgsAccountCreation()) + assert.Nil(t, err) + assert.False(t, check.IfNil(acc)) + }) } func TestUserAccount_AddToBalanceInsufficientFundsShouldErr(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) value := big.NewInt(-1) err := acc.AddToBalance(value) @@ -38,7 +80,7 @@ func TestUserAccount_AddToBalanceInsufficientFundsShouldErr(t *testing.T) { func TestUserAccount_SubFromBalanceInsufficientFundsShouldErr(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) value := big.NewInt(1) err := acc.SubFromBalance(value) @@ -48,7 +90,7 @@ func TestUserAccount_SubFromBalanceInsufficientFundsShouldErr(t *testing.T) { func TestUserAccount_GetBalance(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) balance := big.NewInt(100) subFromBalance := big.NewInt(20) @@ -61,7 +103,7 @@ func TestUserAccount_GetBalance(t *testing.T) { func TestUserAccount_AddToDeveloperReward(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) reward := big.NewInt(10) acc.AddToDeveloperReward(reward) @@ -71,7 +113,7 @@ func TestUserAccount_AddToDeveloperReward(t *testing.T) { func TestUserAccount_ClaimDeveloperRewardsWrongAddressShouldErr(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) val, err := acc.ClaimDeveloperRewards([]byte("wrong address")) assert.Nil(t, val) assert.Equal(t, state.ErrOperationNotPermitted, err) @@ -80,7 +122,7 @@ func TestUserAccount_ClaimDeveloperRewardsWrongAddressShouldErr(t *testing.T) { func TestUserAccount_ClaimDeveloperRewards(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc, _ := state.NewUserAccount(make([]byte, 32), getDefaultArgsAccountCreation()) reward := big.NewInt(10) acc.AddToDeveloperReward(reward) @@ -93,7 +135,7 @@ func TestUserAccount_ClaimDeveloperRewards(t *testing.T) { func TestUserAccount_ChangeOwnerAddressWrongAddressShouldErr(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) err := acc.ChangeOwnerAddress([]byte("wrong address"), []byte{}) assert.Equal(t, state.ErrOperationNotPermitted, err) } @@ -101,7 +143,7 @@ func TestUserAccount_ChangeOwnerAddressWrongAddressShouldErr(t *testing.T) { func TestUserAccount_ChangeOwnerAddressInvalidAddressShouldErr(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc, _ := state.NewUserAccount(make([]byte, 32), getDefaultArgsAccountCreation()) err := acc.ChangeOwnerAddress(acc.OwnerAddress, []byte("new address")) assert.Equal(t, state.ErrInvalidAddressLength, err) } @@ -110,7 +152,7 @@ func TestUserAccount_ChangeOwnerAddress(t *testing.T) { t.Parallel() newAddress := make([]byte, 32) - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc, _ := state.NewUserAccount(make([]byte, 32), getDefaultArgsAccountCreation()) err := acc.ChangeOwnerAddress(acc.OwnerAddress, newAddress) assert.Nil(t, err) @@ -121,7 +163,7 @@ func TestUserAccount_SetOwnerAddress(t *testing.T) { t.Parallel() newAddress := []byte("new address") - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) acc.SetOwnerAddress(newAddress) assert.Equal(t, newAddress, acc.GetOwnerAddress()) @@ -130,7 +172,7 @@ func TestUserAccount_SetOwnerAddress(t *testing.T) { func TestUserAccount_SetAndGetNonce(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) nonce := uint64(5) acc.IncreaseNonce(nonce) @@ -140,7 +182,7 @@ func TestUserAccount_SetAndGetNonce(t *testing.T) { func TestUserAccount_SetAndGetCodeHash(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) codeHash := []byte("code hash") acc.SetCodeHash(codeHash) @@ -150,9 +192,188 @@ func TestUserAccount_SetAndGetCodeHash(t *testing.T) { func TestUserAccount_SetAndGetRootHash(t *testing.T) { t.Parallel() - acc, _ := state.NewUserAccount(make([]byte, 32)) + acc := createUserAcc(make([]byte, 32)) rootHash := []byte("root hash") acc.SetRootHash(rootHash) assert.Equal(t, rootHash, acc.GetRootHash()) } + +func TestUserAccount_GetAllLeaves(t *testing.T) { + t.Parallel() + + t.Run("autoBalance data tries disabled", func(t *testing.T) { + t.Parallel() + + tr, _ := getDefaultTrieAndAccountsDb() + acc, _ := state.NewUserAccount([]byte("address"), getDefaultArgsAccountCreation()) + numKeys := 1000 + vals := make(map[string][]byte) + for i := 0; i < numKeys; i++ { + key := []byte(strconv.Itoa(i)) + val := []byte(strconv.Itoa(i)) + vals[string(key)] = val + err := acc.SaveKeyValue(key, val) + assert.Nil(t, err) + } + acc.SetDataTrie(tr) + _, _ = acc.SaveDirtyData(tr) + rh, _ := acc.DataTrie().RootHash() + acc.SetRootHash(rh) + _ = tr.Commit() + + chLeaves := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, 100), + ErrChan: errChan.NewErrChanWrapper(), + } + err := acc.GetAllLeaves(chLeaves, context.Background()) + assert.Nil(t, err) + + for leaf := range chLeaves.LeavesChan { + val, ok := vals[string(leaf.Key())] + assert.True(t, ok) + assert.Equal(t, val, leaf.Value()) + } + }) + + t.Run("autoBalance data tries enabled", func(t *testing.T) { + t.Parallel() + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tr, _ := getDefaultTrieAndAccountsDb() + args := getDefaultArgsAccountCreation() + args.EnableEpochsHandler = enableEpochsHandler + acc, _ := state.NewUserAccount([]byte("address"), args) + numKeys := 1000 + vals := make(map[string][]byte) + for i := 0; i < numKeys; i++ { + key := []byte(strconv.Itoa(i)) + val := []byte(strconv.Itoa(i)) + vals[string(key)] = val + err := acc.SaveKeyValue(key, val) + assert.Nil(t, err) + } + acc.SetDataTrie(tr) + _, _ = acc.SaveDirtyData(tr) + rh, _ := acc.DataTrie().RootHash() + acc.SetRootHash(rh) + _ = tr.Commit() + + chLeaves := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, 100), + ErrChan: errChan.NewErrChanWrapper(), + } + err := acc.GetAllLeaves(chLeaves, context.Background()) + assert.Nil(t, err) + + for leaf := range chLeaves.LeavesChan { + val, ok := vals[string(leaf.Key())] + assert.True(t, ok) + assert.Equal(t, val, leaf.Value()) + } + }) + + t.Run("autoBalance data tries enabled after insert", func(t *testing.T) { + t.Parallel() + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: false, + } + tr, _ := getDefaultTrieAndAccountsDb() + args := getDefaultArgsAccountCreation() + args.EnableEpochsHandler = enableEpochsHandler + acc, _ := state.NewUserAccount([]byte("address"), args) + numKeys := 1000 + vals := make(map[string][]byte) + for i := 0; i < numKeys; i++ { + key := []byte(strconv.Itoa(i)) + val := []byte(strconv.Itoa(i)) + vals[string(key)] = val + err := acc.SaveKeyValue(key, val) + assert.Nil(t, err) + } + acc.SetDataTrie(tr) + _, _ = acc.SaveDirtyData(tr) + rh, _ := acc.DataTrie().RootHash() + acc.SetRootHash(rh) + _ = tr.Commit() + enableEpochsHandler.IsAutoBalanceDataTriesEnabledField = true + + chLeaves := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, 100), + ErrChan: errChan.NewErrChanWrapper(), + } + err := acc.GetAllLeaves(chLeaves, context.Background()) + assert.Nil(t, err) + + for leaf := range chLeaves.LeavesChan { + val, ok := vals[string(leaf.Key())] + assert.True(t, ok) + assert.Equal(t, val, leaf.Value()) + } + }) +} + +func TestUserAccount_IsDataTrieMigrated(t *testing.T) { + t.Parallel() + + t.Run("nil trie", func(t *testing.T) { + t.Parallel() + + acc, _ := state.NewUserAccount([]byte("address"), getDefaultArgsAccountCreation()) + isMigrated, err := acc.IsDataTrieMigrated() + assert.False(t, isMigrated) + assert.Equal(t, state.ErrNilTrie, err) + }) + + t.Run("trie is not migrated", func(t *testing.T) { + t.Parallel() + + acc, _ := state.NewUserAccount([]byte("address"), getDefaultArgsAccountCreation()) + acc.SetDataTrie( + &trie.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return false, nil + }, + }, + ) + isMigrated, err := acc.IsDataTrieMigrated() + assert.False(t, isMigrated) + assert.Nil(t, err) + }) + + t.Run("trie is migrated", func(t *testing.T) { + t.Parallel() + + acc, _ := state.NewUserAccount([]byte("address"), getDefaultArgsAccountCreation()) + acc.SetDataTrie( + &trie.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return true, nil + }, + }, + ) + isMigrated, err := acc.IsDataTrieMigrated() + assert.True(t, isMigrated) + assert.Nil(t, err) + }) + + t.Run("trie is migrated error", func(t *testing.T) { + t.Parallel() + + expectedErr := fmt.Errorf("expected error") + acc, _ := state.NewUserAccount([]byte("address"), getDefaultArgsAccountCreation()) + acc.SetDataTrie( + &trie.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return false, expectedErr + }, + }, + ) + isMigrated, err := acc.IsDataTrieMigrated() + assert.False(t, isMigrated) + assert.Equal(t, expectedErr, err) + }) +} diff --git a/storage/database/db.go b/storage/database/db.go index be5739dd41d..7e677ed954c 100644 --- a/storage/database/db.go +++ b/storage/database/db.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-storage-go/leveldb" "github.com/multiversx/mx-chain-storage-go/memorydb" + "github.com/multiversx/mx-chain-storage-go/sharded" ) // MemDB represents the memory database storage. It holds a map of key value pairs @@ -31,3 +32,13 @@ func NewLevelDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFil func NewSerialDB(path string, batchDelaySeconds int, maxBatchSize int, maxOpenFiles int) (s *leveldb.SerialDB, err error) { return leveldb.NewSerialDB(path, batchDelaySeconds, maxBatchSize, maxOpenFiles) } + +// NewShardIDProvider is a constructor for shard id provider +func NewShardIDProvider(numShards int32) (storage.ShardIDProvider, error) { + return sharded.NewShardIDProvider(numShards) +} + +// NewShardedPersister is a constructor for sharded persister based on provided db type +func NewShardedPersister(path string, persisterCreator storage.PersisterCreator, idPersister storage.ShardIDProvider) (s storage.Persister, err error) { + return sharded.NewShardedPersister(path, persisterCreator, idPersister) +} diff --git a/storage/disabled/shardIDProvider.go b/storage/disabled/shardIDProvider.go new file mode 100644 index 00000000000..b997230d0f1 --- /dev/null +++ b/storage/disabled/shardIDProvider.go @@ -0,0 +1,28 @@ +package disabled + +type shardIDProvider struct{} + +// NewShardIDProvider returns a new disabled shard id provider instance +func NewShardIDProvider() *shardIDProvider { + return &shardIDProvider{} +} + +// ComputeId returns 0 +func (s *shardIDProvider) ComputeId(key []byte) uint32 { + return 0 +} + +// NumberOfShards returns 0 +func (s *shardIDProvider) NumberOfShards() uint32 { + return 0 +} + +// GetShardIDs returns nil +func (s *shardIDProvider) GetShardIDs() []uint32 { + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *shardIDProvider) IsInterfaceNil() bool { + return s == nil +} diff --git a/storage/disabled/shardIDProvider_test.go b/storage/disabled/shardIDProvider_test.go new file mode 100644 index 00000000000..06ddc7584d6 --- /dev/null +++ b/storage/disabled/shardIDProvider_test.go @@ -0,0 +1,26 @@ +package disabled + +import ( + "fmt" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestShardIDProvider_MethodsDoNotPanic(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, fmt.Sprintf("should have not panicked: %v", r)) + } + }() + + s := NewShardIDProvider() + assert.Zero(t, s.ComputeId([]byte{})) + assert.Zero(t, s.NumberOfShards()) + assert.Nil(t, s.GetShardIDs()) + assert.False(t, check.IfNil(s)) +} diff --git a/storage/errors.go b/storage/errors.go index 14c62db42ac..16e83d927fa 100644 --- a/storage/errors.go +++ b/storage/errors.go @@ -88,6 +88,18 @@ var ErrEpochKeepIsLowerThanNumActive = errors.New("num epochs to keep is lower t // ErrNilPersistersTracker signals that a nil persisters tracker has been provided var ErrNilPersistersTracker = errors.New("nil persisters tracker provided") +// ErrNilShardIDProvider signals that a nil shard id provider has been provided +var ErrNilShardIDProvider = errors.New("nil shard id provider") + +// ErrNotSupportedShardIDProviderType is raised when an unsupported shard id provider type is provided +var ErrNotSupportedShardIDProviderType = errors.New("invalid shard id provider type has been provided") + +// ErrInvalidFilePath signals that an invalid file path has been provided +var ErrInvalidFilePath = errors.New("invalid file path") + +// ErrNilDBConfigHandler signals that a nil db config handler has been provided +var ErrNilDBConfigHandler = errors.New("nil db config handler") + // ErrNilManagedPeersHolder signals that a nil managed peers holder has been provided var ErrNilManagedPeersHolder = errors.New("nil managed peers holder") diff --git a/storage/factory/dbConfigHandler.go b/storage/factory/dbConfigHandler.go new file mode 100644 index 00000000000..a2e2797f3a6 --- /dev/null +++ b/storage/factory/dbConfigHandler.go @@ -0,0 +1,141 @@ +package factory + +import ( + "io/ioutil" + "os" + "path/filepath" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/config" +) + +const ( + dbConfigFileName = "config.toml" + defaultType = "LvlDBSerial" + defaultBatchDelaySeconds = 2 + defaultMaxBatchSize = 100 + defaultMaxOpenFiles = 10 +) + +type dbConfigHandler struct { + dbType string + batchDelaySeconds int + maxBatchSize int + maxOpenFiles int + shardIDProviderType string + numShards int32 +} + +// NewDBConfigHandler will create a new db config handler instance +func NewDBConfigHandler(config config.DBConfig) *dbConfigHandler { + return &dbConfigHandler{ + dbType: config.Type, + batchDelaySeconds: config.BatchDelaySeconds, + maxBatchSize: config.MaxBatchSize, + maxOpenFiles: config.MaxOpenFiles, + shardIDProviderType: config.ShardIDProviderType, + numShards: config.NumShards, + } +} + +// GetDBConfig will get the db config based on path +func (dh *dbConfigHandler) GetDBConfig(path string) (*config.DBConfig, error) { + dbConfigFromFile := &config.DBConfig{} + err := core.LoadTomlFile(dbConfigFromFile, getPersisterConfigFilePath(path)) + if err == nil { + log.Debug("GetDBConfig: loaded db config from toml config file", "path", dbConfigFromFile) + return dbConfigFromFile, nil + } + + empty := checkIfDirIsEmpty(path) + if !empty { + dbConfig := &config.DBConfig{ + Type: defaultType, + BatchDelaySeconds: defaultBatchDelaySeconds, + MaxBatchSize: defaultMaxBatchSize, + MaxOpenFiles: defaultMaxOpenFiles, + } + + log.Debug("GetDBConfig: loaded default db config") + return dbConfig, nil + } + + dbConfig := &config.DBConfig{ + Type: dh.dbType, + BatchDelaySeconds: dh.batchDelaySeconds, + MaxBatchSize: dh.maxBatchSize, + MaxOpenFiles: dh.maxOpenFiles, + ShardIDProviderType: dh.shardIDProviderType, + NumShards: dh.numShards, + } + + log.Debug("GetDBConfig: loaded db config from main config file") + return dbConfig, nil +} + +// SaveDBConfigToFilePath will save the provided db config to specified path +func (dh *dbConfigHandler) SaveDBConfigToFilePath(path string, dbConfig *config.DBConfig) error { + pathExists, err := checkIfDirExists(path) + if err != nil { + return err + } + if !pathExists { + // in memory db, no files available + return nil + } + + configFilePath := getPersisterConfigFilePath(path) + + loadedDBConfig := &config.DBConfig{} + err = core.LoadTomlFile(loadedDBConfig, configFilePath) + if err == nil { + // config file already exists, no need to save config + return nil + } + + err = core.SaveTomlFile(dbConfig, configFilePath) + if err != nil { + return err + } + + return nil +} + +func getPersisterConfigFilePath(path string) string { + return filepath.Join( + path, + dbConfigFileName, + ) +} + +func checkIfDirExists(path string) (bool, error) { + _, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + + return false, err + } + + return true, nil +} + +func checkIfDirIsEmpty(path string) bool { + files, err := ioutil.ReadDir(path) + if err != nil { + log.Trace("getDBConfig: failed to check if dir is empty", "path", path, "error", err.Error()) + return true + } + + if len(files) == 0 { + return true + } + + return false +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dh *dbConfigHandler) IsInterfaceNil() bool { + return dh == nil +} diff --git a/storage/factory/dbConfigHandler_test.go b/storage/factory/dbConfigHandler_test.go new file mode 100644 index 00000000000..406218be7dc --- /dev/null +++ b/storage/factory/dbConfigHandler_test.go @@ -0,0 +1,167 @@ +package factory_test + +import ( + "os" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/storage/factory" + "github.com/stretchr/testify/require" +) + +func createDefaultDBConfig() config.DBConfig { + return config.DBConfig{ + Type: "LvlDBSerial", + BatchDelaySeconds: 2, + MaxBatchSize: 100, + MaxOpenFiles: 10, + UseTmpAsFilePath: false, + ShardIDProviderType: "BinarySplit", + NumShards: 4, + } +} + +func TestDBConfigHandler_GetDBConfig(t *testing.T) { + t.Parallel() + + t.Run("load db config from toml config file", func(t *testing.T) { + t.Parallel() + + pf := factory.NewDBConfigHandler(createDefaultDBConfig()) + + dirPath := t.TempDir() + configPath := factory.GetPersisterConfigFilePath(dirPath) + + expectedDBConfig := config.DBConfig{ + FilePath: "filepath1", + Type: "type1", + BatchDelaySeconds: 1, + MaxBatchSize: 2, + MaxOpenFiles: 3, + NumShards: 4, + } + + err := core.SaveTomlFile(expectedDBConfig, configPath) + require.Nil(t, err) + + conf, err := pf.GetDBConfig(dirPath) + require.Nil(t, err) + require.Equal(t, &expectedDBConfig, conf) + }) + + t.Run("not empty dir, load default db config", func(t *testing.T) { + t.Parallel() + + pf := factory.NewDBConfigHandler(createDefaultDBConfig()) + + dirPath := t.TempDir() + + f, err := core.CreateFile(core.ArgCreateFileArgument{ + Directory: dirPath, + Prefix: "test", + FileExtension: "log", + }) + require.Nil(t, err) + + defer func() { + _ = f.Close() + }() + + expectedDBConfig := factory.GetDefaultDBConfig() + + conf, err := pf.GetDBConfig(dirPath) + require.Nil(t, err) + require.Equal(t, expectedDBConfig, conf) + }) + + t.Run("empty dir, load db config from main config", func(t *testing.T) { + t.Parallel() + + expectedDBConfig := createDefaultDBConfig() + + pf := factory.NewDBConfigHandler(createDefaultDBConfig()) + + dirPath := t.TempDir() + + conf, err := pf.GetDBConfig(dirPath) + require.Nil(t, err) + require.Equal(t, &expectedDBConfig, conf) + }) + + t.Run("getDBConfig twice, should load from config file if file available", func(t *testing.T) { + t.Parallel() + + expectedDBConfig := createDefaultDBConfig() + + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + + dirPath := t.TempDir() + + conf, err := dbConfigHandler.GetDBConfig(dirPath) + require.Nil(t, err) + require.Equal(t, &expectedDBConfig, conf) + + newDBConfig := config.DBConfig{ + Type: "type1", + BatchDelaySeconds: 1, + MaxBatchSize: 2, + MaxOpenFiles: 3, + NumShards: 4, + } + + configPath := factory.GetPersisterConfigFilePath(dirPath) + + err = core.SaveTomlFile(expectedDBConfig, configPath) + require.Nil(t, err) + + dbConfigHandler = factory.NewDBConfigHandler(newDBConfig) + conf, err = dbConfigHandler.GetDBConfig(dirPath) + require.Nil(t, err) + require.Equal(t, &expectedDBConfig, conf) + }) +} + +func TestDBConfigHandler_SaveDBConfigToFilePath(t *testing.T) { + t.Parallel() + + t.Run("no path present, in memory db, should not fail", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultDBConfig() + + pf := factory.NewDBConfigHandler(dbConfig) + err := pf.SaveDBConfigToFilePath("no/valid/path", &dbConfig) + require.Nil(t, err) + }) + + t.Run("config file already present, should not fail", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultDBConfig() + dirPath := t.TempDir() + configPath := factory.GetPersisterConfigFilePath(dirPath) + + err := core.SaveTomlFile(dbConfig, configPath) + require.Nil(t, err) + + pf := factory.NewDBConfigHandler(dbConfig) + err = pf.SaveDBConfigToFilePath(dirPath, &dbConfig) + require.Nil(t, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultDBConfig() + dirPath := t.TempDir() + + pf := factory.NewDBConfigHandler(dbConfig) + err := pf.SaveDBConfigToFilePath(dirPath, &dbConfig) + require.Nil(t, err) + + configPath := factory.GetPersisterConfigFilePath(dirPath) + _, err = os.Stat(configPath) + require.False(t, os.IsNotExist(err)) + }) +} diff --git a/storage/factory/export_test.go b/storage/factory/export_test.go new file mode 100644 index 00000000000..4b5ac54baac --- /dev/null +++ b/storage/factory/export_test.go @@ -0,0 +1,31 @@ +package factory + +import ( + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/storage" +) + +// GetPersisterConfigFilePath - +func GetPersisterConfigFilePath(path string) string { + return getPersisterConfigFilePath(path) +} + +// GetDefaultDBConfig - +func GetDefaultDBConfig() *config.DBConfig { + return &config.DBConfig{ + Type: defaultType, + BatchDelaySeconds: defaultBatchDelaySeconds, + MaxBatchSize: defaultMaxBatchSize, + MaxOpenFiles: defaultMaxOpenFiles, + } +} + +// NewPersisterCreator - +func NewPersisterCreator(config config.DBConfig) *persisterCreator { + return newPersisterCreator(config) +} + +// CreateShardIDProvider - +func (pc *persisterCreator) CreateShardIDProvider() (storage.ShardIDProvider, error) { + return pc.createShardIDProvider() +} diff --git a/storage/factory/openStorage.go b/storage/factory/openStorage.go index 2f02327cc02..80dae5bc39c 100644 --- a/storage/factory/openStorage.go +++ b/storage/factory/openStorage.go @@ -56,7 +56,11 @@ func (o *openStorageUnits) GetMostRecentStorageUnit(dbConfig config.DBConfig) (s return nil, err } - persisterFactory := NewPersisterFactory(dbConfig) + dbConfigHandler := NewDBConfigHandler(dbConfig) + persisterFactory, err := NewPersisterFactory(dbConfigHandler) + if err != nil { + return nil, err + } pathWithoutShard := o.getPathWithoutShard(parentDir, lastEpoch) shardIdsStr, err := o.latestStorageDataProvider.GetShardsFromDirectory(pathWithoutShard) if err != nil { @@ -108,7 +112,11 @@ func (o *openStorageUnits) OpenDB(dbConfig config.DBConfig, shardID uint32, epoc parentDir := o.latestStorageDataProvider.GetParentDirectory() pathWithoutShard := o.getPathWithoutShard(parentDir, epoch) persisterPath := o.getPersisterPath(pathWithoutShard, fmt.Sprintf("%d", shardID), dbConfig) - persisterFactory := NewPersisterFactory(dbConfig) + dbConfigHandler := NewDBConfigHandler(dbConfig) + persisterFactory, err := NewPersisterFactory(dbConfigHandler) + if err != nil { + return nil, err + } persister, err := createDB(persisterFactory, persisterPath) if err != nil { diff --git a/storage/factory/openStorage_test.go b/storage/factory/openStorage_test.go index 69a81bc1f67..1a1273df5f4 100644 --- a/storage/factory/openStorage_test.go +++ b/storage/factory/openStorage_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" "github.com/multiversx/mx-chain-go/storage" @@ -30,7 +31,7 @@ func TestNewStorageUnitOpenHandler(t *testing.T) { suoh, err := NewStorageUnitOpenHandler(createMockArgsOpenStorageUnits()) assert.NoError(t, err) - assert.NotNil(t, suoh) + assert.False(t, check.IfNil(suoh)) }) t.Run("nil BootstrapDataProvider should error", func(t *testing.T) { t.Parallel() diff --git a/storage/factory/persisterCreator.go b/storage/factory/persisterCreator.go new file mode 100644 index 00000000000..9c0a87bebf8 --- /dev/null +++ b/storage/factory/persisterCreator.go @@ -0,0 +1,77 @@ +package factory + +import ( + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/database" + "github.com/multiversx/mx-chain-go/storage/storageunit" +) + +const minNumShards = 2 + +// persisterCreator is the factory which will handle creating new persisters +type persisterCreator struct { + dbType string + batchDelaySeconds int + maxBatchSize int + maxOpenFiles int + shardIDProviderType string + numShards int32 +} + +func newPersisterCreator(config config.DBConfig) *persisterCreator { + return &persisterCreator{ + dbType: config.Type, + batchDelaySeconds: config.BatchDelaySeconds, + maxBatchSize: config.MaxBatchSize, + maxOpenFiles: config.MaxOpenFiles, + shardIDProviderType: config.ShardIDProviderType, + numShards: config.NumShards, + } +} + +// Create will create the persister for the provided path +func (pc *persisterCreator) Create(path string) (storage.Persister, error) { + if len(path) == 0 { + return nil, storage.ErrInvalidFilePath + } + + if pc.numShards < minNumShards { + return pc.CreateBasePersister(path) + } + + shardIDProvider, err := pc.createShardIDProvider() + if err != nil { + return nil, err + } + return database.NewShardedPersister(path, pc, shardIDProvider) +} + +// CreateBasePersister will create base the persister for the provided path +func (pc *persisterCreator) CreateBasePersister(path string) (storage.Persister, error) { + var dbType = storageunit.DBType(pc.dbType) + switch dbType { + case storageunit.LvlDB: + return database.NewLevelDB(path, pc.batchDelaySeconds, pc.maxBatchSize, pc.maxOpenFiles) + case storageunit.LvlDBSerial: + return database.NewSerialDB(path, pc.batchDelaySeconds, pc.maxBatchSize, pc.maxOpenFiles) + case storageunit.MemoryDB: + return database.NewMemDB(), nil + default: + return nil, storage.ErrNotSupportedDBType + } +} + +func (pc *persisterCreator) createShardIDProvider() (storage.ShardIDProvider, error) { + switch storageunit.ShardIDProviderType(pc.shardIDProviderType) { + case storageunit.BinarySplit: + return database.NewShardIDProvider(pc.numShards) + default: + return nil, storage.ErrNotSupportedShardIDProviderType + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (pc *persisterCreator) IsInterfaceNil() bool { + return pc == nil +} diff --git a/storage/factory/persisterCreator_test.go b/storage/factory/persisterCreator_test.go new file mode 100644 index 00000000000..a0fdef7e1ef --- /dev/null +++ b/storage/factory/persisterCreator_test.go @@ -0,0 +1,155 @@ +package factory_test + +import ( + "fmt" + "strings" + "testing" + + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/factory" + "github.com/multiversx/mx-chain-go/storage/storageunit" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createDefaultBasePersisterConfig() config.DBConfig { + return config.DBConfig{ + Type: "LvlDBSerial", + BatchDelaySeconds: 2, + MaxBatchSize: 100, + MaxOpenFiles: 10, + UseTmpAsFilePath: false, + ShardIDProviderType: "BinarySplit", + NumShards: 1, + } +} + +func TestPersisterCreator_Create(t *testing.T) { + t.Parallel() + + t.Run("invalid file path, should fail", func(t *testing.T) { + t.Parallel() + + pc := factory.NewPersisterCreator(createDefaultDBConfig()) + + p, err := pc.Create("") + require.Nil(t, p) + require.Equal(t, storage.ErrInvalidFilePath, err) + }) + + t.Run("should create non sharded persister", func(t *testing.T) { + t.Parallel() + + pc := factory.NewPersisterCreator(createDefaultBasePersisterConfig()) + + dir := t.TempDir() + p, err := pc.Create(dir) + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*leveldb.SerialDB")) + }) + + t.Run("should create sharded persister", func(t *testing.T) { + t.Parallel() + + pc := factory.NewPersisterCreator(createDefaultDBConfig()) + + dir := t.TempDir() + p, err := pc.Create(dir) + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*sharded.shardedPersister")) + }) +} + +func TestPersisterCreator_CreateBasePersister(t *testing.T) { + t.Parallel() + + t.Run("not supported type, should fail", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = "not supported type" + pc := factory.NewPersisterCreator(dbConfig) + + dir := t.TempDir() + p, err := pc.CreateBasePersister(dir) + require.Nil(t, p) + require.Equal(t, storage.ErrNotSupportedDBType, err) + }) + + t.Run("leveldb", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.LvlDB) + pc := factory.NewPersisterCreator(dbConfig) + + dir := t.TempDir() + p, err := pc.CreateBasePersister(dir) + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*leveldb.DB")) + }) + + t.Run("serial leveldb", func(t *testing.T) { + t.Parallel() + + pc := factory.NewPersisterCreator(createDefaultBasePersisterConfig()) + + dir := t.TempDir() + p, err := pc.CreateBasePersister(dir) + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*leveldb.SerialDB")) + }) + + t.Run("memorydb", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.MemoryDB) + pc := factory.NewPersisterCreator(dbConfig) + + dir := t.TempDir() + p, err := pc.CreateBasePersister(dir) + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*memorydb.DB")) + }) +} + +func TestPersisterCreator_CreateShardIDProvider(t *testing.T) { + t.Parallel() + + t.Run("not supported type, should fail", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultDBConfig() + dbConfig.ShardIDProviderType = "not supported type" + pc := factory.NewPersisterCreator(dbConfig) + + p, err := pc.CreateShardIDProvider() + require.Nil(t, p) + require.Equal(t, storage.ErrNotSupportedShardIDProviderType, err) + }) + + t.Run("binary split, should work", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultDBConfig() + pc := factory.NewPersisterCreator(dbConfig) + + p, err := pc.CreateShardIDProvider() + require.NotNil(t, p) + require.Nil(t, err) + + assert.True(t, strings.Contains(fmt.Sprintf("%T", p), "*sharded.shardIDProvider")) + }) +} diff --git a/storage/factory/persisterFactory.go b/storage/factory/persisterFactory.go index 55b3d45806a..a1305ec2184 100644 --- a/storage/factory/persisterFactory.go +++ b/storage/factory/persisterFactory.go @@ -1,49 +1,51 @@ package factory import ( - "errors" - - "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/disabled" - "github.com/multiversx/mx-chain-go/storage/storageunit" ) // PersisterFactory is the factory which will handle creating new databases type PersisterFactory struct { - dbType string - batchDelaySeconds int - maxBatchSize int - maxOpenFiles int + dbConfigHandler storage.DBConfigHandler } // NewPersisterFactory will return a new instance of a PersisterFactory -func NewPersisterFactory(config config.DBConfig) *PersisterFactory { - return &PersisterFactory{ - dbType: config.Type, - batchDelaySeconds: config.BatchDelaySeconds, - maxBatchSize: config.MaxBatchSize, - maxOpenFiles: config.MaxOpenFiles, +func NewPersisterFactory(dbConfigHandler storage.DBConfigHandler) (*PersisterFactory, error) { + if check.IfNil(dbConfigHandler) { + return nil, storage.ErrNilDBConfigHandler } + + return &PersisterFactory{ + dbConfigHandler: dbConfigHandler, + }, nil } // Create will return a new instance of a DB with a given path func (pf *PersisterFactory) Create(path string) (storage.Persister, error) { if len(path) == 0 { - return nil, errors.New("invalid file path") + return nil, storage.ErrInvalidFilePath + } + + dbConfig, err := pf.dbConfigHandler.GetDBConfig(path) + if err != nil { + return nil, err } - switch storageunit.DBType(pf.dbType) { - case storageunit.LvlDB: - return database.NewLevelDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) - case storageunit.LvlDBSerial: - return database.NewSerialDB(path, pf.batchDelaySeconds, pf.maxBatchSize, pf.maxOpenFiles) - case storageunit.MemoryDB: - return database.NewMemDB(), nil - default: - return nil, storage.ErrNotSupportedDBType + pc := newPersisterCreator(*dbConfig) + + persister, err := pc.Create(path) + if err != nil { + return nil, err } + + err = pf.dbConfigHandler.SaveDBConfigToFilePath(path, dbConfig) + if err != nil { + return nil, err + } + + return persister, nil } // CreateDisabled will return a new disabled persister diff --git a/storage/factory/persisterFactory_test.go b/storage/factory/persisterFactory_test.go index 3aee6a0132e..208542a665b 100644 --- a/storage/factory/persisterFactory_test.go +++ b/storage/factory/persisterFactory_test.go @@ -1,90 +1,144 @@ -package factory +package factory_test import ( "fmt" + "os" "testing" - "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/factory" + "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func createDBConfig(dbType string) config.DBConfig { - return config.DBConfig{ - FilePath: "TEST", - Type: dbType, - BatchDelaySeconds: 5, - MaxBatchSize: 100, - MaxOpenFiles: 10, - UseTmpAsFilePath: false, - } -} - func TestNewPersisterFactory(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("LvlDB")) - assert.NotNil(t, factoryInstance) + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + pf, err := factory.NewPersisterFactory(dbConfigHandler) + require.NotNil(t, pf) + require.Nil(t, err) } func TestPersisterFactory_Create(t *testing.T) { t.Parallel() - t.Run("empty path should error", func(t *testing.T) { + t.Run("invalid file path, should fail", func(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("LvlDB")) - persisterInstance, err := factoryInstance.Create("") - assert.True(t, check.IfNil(persisterInstance)) - expectedErrString := "invalid file path" - assert.Equal(t, expectedErrString, err.Error()) + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + p, err := pf.Create("") + require.Nil(t, p) + require.Equal(t, storage.ErrInvalidFilePath, err) }) - t.Run("unknown type should error", func(t *testing.T) { + + t.Run("should work", func(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("invalid type")) - persisterInstance, err := factoryInstance.Create(t.TempDir()) - assert.True(t, check.IfNil(persisterInstance)) - assert.Equal(t, storage.ErrNotSupportedDBType, err) + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + dir := t.TempDir() + + p, err := pf.Create(dir) + require.NotNil(t, p) + require.Nil(t, err) }) - t.Run("for LvlDB should work", func(t *testing.T) { +} + +func TestPersisterFactory_Create_ConfigSaveToFilePath(t *testing.T) { + t.Parallel() + + t.Run("should write toml config file for leveldb", func(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("LvlDB")) - persisterInstance, err := factoryInstance.Create(t.TempDir()) - assert.Nil(t, err) - assert.False(t, check.IfNil(persisterInstance)) - assert.Equal(t, "*leveldb.DB", fmt.Sprintf("%T", persisterInstance)) - _ = persisterInstance.Close() + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.LvlDB) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + dir := t.TempDir() + path := dir + "storer/" + + p, err := pf.Create(path) + require.NotNil(t, p) + require.Nil(t, err) + + configPath := factory.GetPersisterConfigFilePath(path) + _, err = os.Stat(configPath) + require.False(t, os.IsNotExist(err)) }) - t.Run("for LvlDBSerial should work", func(t *testing.T) { + + t.Run("should write toml config file for serial leveldb", func(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("LvlDBSerial")) - persisterInstance, err := factoryInstance.Create(t.TempDir()) - assert.Nil(t, err) - assert.False(t, check.IfNil(persisterInstance)) - assert.Equal(t, "*leveldb.SerialDB", fmt.Sprintf("%T", persisterInstance)) - _ = persisterInstance.Close() + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.LvlDBSerial) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + dir := t.TempDir() + path := dir + "storer/" + + p, err := pf.Create(path) + require.NotNil(t, p) + require.Nil(t, err) + + configPath := factory.GetPersisterConfigFilePath(path) + _, err = os.Stat(configPath) + require.False(t, os.IsNotExist(err)) + }) + + t.Run("should not write toml config file for memory db", func(t *testing.T) { + t.Parallel() + + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.MemoryDB) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + dir := t.TempDir() + path := dir + "storer/" + + p, err := pf.Create(path) + require.NotNil(t, p) + require.Nil(t, err) + + configPath := factory.GetPersisterConfigFilePath(path) + _, err = os.Stat(configPath) + require.True(t, os.IsNotExist(err)) }) - t.Run("for MemoryDB should work", func(t *testing.T) { + + t.Run("should not create path dir for memory db", func(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("MemoryDB")) - persisterInstance, err := factoryInstance.Create(t.TempDir()) - assert.Nil(t, err) - assert.False(t, check.IfNil(persisterInstance)) - assert.Equal(t, "*memorydb.DB", fmt.Sprintf("%T", persisterInstance)) - _ = persisterInstance.Close() + dbConfig := createDefaultBasePersisterConfig() + dbConfig.Type = string(storageunit.MemoryDB) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + pf, _ := factory.NewPersisterFactory(dbConfigHandler) + + dir := t.TempDir() + path := dir + "storer/" + + p, err := pf.Create(path) + require.NotNil(t, p) + require.Nil(t, err) + + _, err = os.Stat(path) + require.True(t, os.IsNotExist(err)) }) } func TestPersisterFactory_CreateDisabled(t *testing.T) { t.Parallel() - factoryInstance := NewPersisterFactory(createDBConfig("LvlDB")) + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + factoryInstance, err := factory.NewPersisterFactory(dbConfigHandler) + require.Nil(t, err) + persisterInstance := factoryInstance.CreateDisabled() assert.NotNil(t, persisterInstance) assert.Equal(t, "*disabled.errorDisabledPersister", fmt.Sprintf("%T", persisterInstance)) @@ -93,9 +147,10 @@ func TestPersisterFactory_CreateDisabled(t *testing.T) { func TestPersisterFactory_IsInterfaceNil(t *testing.T) { t.Parallel() - var pf *PersisterFactory + var pf *factory.PersisterFactory require.True(t, pf.IsInterfaceNil()) - pf = NewPersisterFactory(config.DBConfig{}) + dbConfigHandler := factory.NewDBConfigHandler(createDefaultDBConfig()) + pf, _ = factory.NewPersisterFactory(dbConfigHandler) require.False(t, pf.IsInterfaceNil()) } diff --git a/storage/factory/storageServiceFactory.go b/storage/factory/storageServiceFactory.go index d1030f1a479..0416bd6b41b 100644 --- a/storage/factory/storageServiceFactory.go +++ b/storage/factory/storageServiceFactory.go @@ -132,56 +132,80 @@ func (psf *StorageServiceFactory) createAndAddBaseStorageUnits( ) error { disabledCustomDatabaseRemover := disabled.NewDisabledCustomDatabaseRemover() - txUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.TxStorage, disabledCustomDatabaseRemover) + txUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.TxStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } txUnit, err := psf.createPruningPersister(txUnitArgs) if err != nil { return fmt.Errorf("%w for TxStorage", err) } store.AddStorer(dataRetriever.TransactionUnit, txUnit) - unsignedTxUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.UnsignedTransactionStorage, disabledCustomDatabaseRemover) + unsignedTxUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.UnsignedTransactionStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } unsignedTxUnit, err := psf.createPruningPersister(unsignedTxUnitArgs) if err != nil { return fmt.Errorf("%w for UnsignedTransactionStorage", err) } store.AddStorer(dataRetriever.UnsignedTransactionUnit, unsignedTxUnit) - rewardTxUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.RewardTxStorage, disabledCustomDatabaseRemover) + rewardTxUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.RewardTxStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } rewardTxUnit, err := psf.createPruningPersister(rewardTxUnitArgs) if err != nil { return fmt.Errorf("%w for RewardTxStorage", err) } store.AddStorer(dataRetriever.RewardTransactionUnit, rewardTxUnit) - receiptsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.ReceiptsStorage, disabledCustomDatabaseRemover) + receiptsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.ReceiptsStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } receiptsUnit, err := psf.createPruningPersister(receiptsUnitArgs) if err != nil { return fmt.Errorf("%w for ReceiptsStorage", err) } store.AddStorer(dataRetriever.ReceiptsUnit, receiptsUnit) - scheduledSCRsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.ScheduledSCRsStorage, disabledCustomDatabaseRemover) + scheduledSCRsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.ScheduledSCRsStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } scheduledSCRsUnit, err := psf.createPruningPersister(scheduledSCRsUnitArgs) if err != nil { return fmt.Errorf("%w for ScheduledSCRsStorage", err) } store.AddStorer(dataRetriever.ScheduledSCRsUnit, scheduledSCRsUnit) - bootstrapUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.BootstrapStorage, disabledCustomDatabaseRemover) + bootstrapUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.BootstrapStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } bootstrapUnit, err := psf.createPruningPersister(bootstrapUnitArgs) if err != nil { return fmt.Errorf("%w for BootstrapStorage", err) } store.AddStorer(dataRetriever.BootstrapUnit, bootstrapUnit) - miniBlockUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.MiniBlocksStorage, disabledCustomDatabaseRemover) + miniBlockUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.MiniBlocksStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } miniBlockUnit, err := psf.createPruningPersister(miniBlockUnitArgs) if err != nil { return fmt.Errorf("%w for MiniBlocksStorage", err) } store.AddStorer(dataRetriever.MiniBlockUnit, miniBlockUnit) - metaBlockUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.MetaBlockStorage, disabledCustomDatabaseRemover) + metaBlockUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.MetaBlockStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } metaBlockUnit, err := psf.createPruningPersister(metaBlockUnitArgs) if err != nil { return fmt.Errorf("%w for MetaBlockStorage", err) @@ -200,7 +224,10 @@ func (psf *StorageServiceFactory) createAndAddBaseStorageUnits( } store.AddStorer(dataRetriever.MetaHdrNonceHashDataUnit, metaHdrHashNonceUnit) - headerUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.BlockHeaderStorage, disabledCustomDatabaseRemover) + headerUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.BlockHeaderStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } headerUnit, err := psf.createPruningPersister(headerUnitArgs) if err != nil { return fmt.Errorf("%w for BlockHeaderStorage", err) @@ -213,14 +240,20 @@ func (psf *StorageServiceFactory) createAndAddBaseStorageUnits( } store.AddStorer(dataRetriever.UserAccountsUnit, userAccountsUnit) - userAccountsCheckpointsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.AccountsTrieCheckpointsStorage, disabledCustomDatabaseRemover) + userAccountsCheckpointsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.AccountsTrieCheckpointsStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } userAccountsCheckpointsUnit, err := psf.createPruningPersister(userAccountsCheckpointsUnitArgs) if err != nil { return fmt.Errorf("%w for AccountsTrieCheckpointsStorage", err) } store.AddStorer(dataRetriever.UserAccountsCheckpointsUnit, userAccountsCheckpointsUnit) - peerAccountsCheckpointsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.PeerAccountsTrieCheckpointsStorage, disabledCustomDatabaseRemover) + peerAccountsCheckpointsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.PeerAccountsTrieCheckpointsStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } peerAccountsCheckpointsUnit, err := psf.createPruningPersister(peerAccountsCheckpointsUnitArgs) if err != nil { return fmt.Errorf("%w for PeerAccountsTrieCheckpointsStorage", err) @@ -278,14 +311,20 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService return nil, err } - peerAccountsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.PeerAccountsTrieStorage, customDatabaseRemover) + peerAccountsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.PeerAccountsTrieStorage, customDatabaseRemover) + if err != nil { + return nil, err + } peerAccountsUnit, err := psf.createTrieUnit(psf.generalConfig.PeerAccountsTrieStorage, peerAccountsUnitArgs) if err != nil { return nil, fmt.Errorf("%w for PeerAccountsTrieStorage", err) } store.AddStorer(dataRetriever.PeerAccountsUnit, peerAccountsUnit) - peerBlockUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.PeerBlockBodyStorage, disabledCustomDatabaseRemover) + peerBlockUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.PeerBlockBodyStorage, disabledCustomDatabaseRemover) + if err != nil { + return nil, err + } peerBlockUnit, err := psf.createPruningPersister(peerBlockUnitArgs) if err != nil { return nil, fmt.Errorf("%w for PeerBlockBodyStorage", err) @@ -377,7 +416,11 @@ func (psf *StorageServiceFactory) createTriePruningStorer( storageConfig config.StorageConfig, customDatabaseRemover storage.CustomDatabaseRemoverHandler, ) (storage.Storer, error) { - accountsUnitArgs := psf.createPruningStorerArgs(storageConfig, customDatabaseRemover) + accountsUnitArgs, err := psf.createPruningStorerArgs(storageConfig, customDatabaseRemover) + if err != nil { + return nil, err + } + if psf.storageType == ProcessStorageService && psf.nodeProcessingMode == common.Normal { accountsUnitArgs.PersistersTracker = pruning.NewTriePersisterTracker(accountsUnitArgs.EpochsData) } @@ -406,7 +449,10 @@ func (psf *StorageServiceFactory) setUpLogsAndEventsStorer(chainStorer *dataRetr shouldCreateStorer := psf.generalConfig.LogsAndEvents.SaveInStorageEnabled || psf.generalConfig.DbLookupExtensions.Enabled if shouldCreateStorer { var err error - txLogsUnitArgs := psf.createPruningStorerArgs(psf.generalConfig.LogsAndEvents.TxLogsStorage, disabled.NewDisabledCustomDatabaseRemover()) + txLogsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.LogsAndEvents.TxLogsStorage, disabled.NewDisabledCustomDatabaseRemover()) + if err != nil { + return err + } txLogsUnit, err = psf.createPruningPersister(txLogsUnitArgs) if err != nil { return fmt.Errorf("%w for LogsAndEvents.TxLogsStorage", err) @@ -427,7 +473,10 @@ func (psf *StorageServiceFactory) setUpDbLookupExtensions(chainStorer *dataRetri // Create the eventsHashesByTxHash (PRUNING) storer eventsHashesByTxHashConfig := psf.generalConfig.DbLookupExtensions.ResultsHashesByTxHashStorageConfig - eventsHashesByTxHashStorerArgs := psf.createPruningStorerArgs(eventsHashesByTxHashConfig, disabled.NewDisabledCustomDatabaseRemover()) + eventsHashesByTxHashStorerArgs, err := psf.createPruningStorerArgs(eventsHashesByTxHashConfig, disabled.NewDisabledCustomDatabaseRemover()) + if err != nil { + return err + } eventsHashesByTxHashPruningStorer, err := psf.createPruningPersister(eventsHashesByTxHashStorerArgs) if err != nil { return fmt.Errorf("%w for DbLookupExtensions.ResultsHashesByTxHashStorageConfig", err) @@ -437,7 +486,10 @@ func (psf *StorageServiceFactory) setUpDbLookupExtensions(chainStorer *dataRetri // Create the miniblocksMetadata (PRUNING) storer miniblocksMetadataConfig := psf.generalConfig.DbLookupExtensions.MiniblocksMetadataStorageConfig - miniblocksMetadataPruningStorerArgs := psf.createPruningStorerArgs(miniblocksMetadataConfig, disabled.NewDisabledCustomDatabaseRemover()) + miniblocksMetadataPruningStorerArgs, err := psf.createPruningStorerArgs(miniblocksMetadataConfig, disabled.NewDisabledCustomDatabaseRemover()) + if err != nil { + return err + } miniblocksMetadataPruningStorer, err := psf.createPruningPersister(miniblocksMetadataPruningStorerArgs) if err != nil { return fmt.Errorf("%w for DbLookupExtensions.MiniblocksMetadataStorageConfig", err) @@ -519,7 +571,7 @@ func (psf *StorageServiceFactory) createEsdtSuppliesUnit(shardIDStr string) (sto func (psf *StorageServiceFactory) createPruningStorerArgs( storageConfig config.StorageConfig, customDatabaseRemover storage.CustomDatabaseRemoverHandler, -) pruning.StorerArgs { +) (pruning.StorerArgs, error) { numOfEpochsToKeep := uint32(psf.generalConfig.StoragePruning.NumEpochsToKeep) numOfActivePersisters := uint32(psf.generalConfig.StoragePruning.NumActivePersisters) pruningEnabled := psf.generalConfig.StoragePruning.Enabled @@ -530,6 +582,13 @@ func (psf *StorageServiceFactory) createPruningStorerArgs( NumOfEpochsToKeep: numOfEpochsToKeep, NumOfActivePersisters: numOfActivePersisters, } + + dbConfigHandler := NewDBConfigHandler(storageConfig.DB) + persisterFactory, err := NewPersisterFactory(dbConfigHandler) + if err != nil { + return pruning.StorerArgs{}, err + } + args := pruning.StorerArgs{ Identifier: storageConfig.DB.FilePath, PruningEnabled: pruningEnabled, @@ -539,7 +598,7 @@ func (psf *StorageServiceFactory) createPruningStorerArgs( CacheConf: GetCacherFromConfig(storageConfig.Cache), PathManager: psf.pathManager, DbPath: dbPath, - PersisterFactory: NewPersisterFactory(storageConfig.DB), + PersisterFactory: persisterFactory, Notifier: psf.epochStartNotifier, MaxBatchSize: storageConfig.DB.MaxBatchSize, EnabledDbLookupExtensions: psf.generalConfig.DbLookupExtensions.Enabled, @@ -547,7 +606,7 @@ func (psf *StorageServiceFactory) createPruningStorerArgs( EpochsData: epochsData, } - return args + return args, nil } func (psf *StorageServiceFactory) createTrieEpochRootHashStorerIfNeeded() (storage.Storer, error) { diff --git a/storage/interface.go b/storage/interface.go index 8f84cb400d6..d5bdc49c081 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -47,25 +47,7 @@ type Cacher interface { } // Persister provides storage of data services in a database like construct -type Persister interface { - // Put add the value to the (key, val) persistence medium - Put(key, val []byte) error - // Get gets the value associated to the key - Get(key []byte) ([]byte, error) - // Has returns true if the given key is present in the persistence medium - Has(key []byte) error - // Close closes the files/resources associated to the persistence medium - Close() error - // Remove removes the data associated to the given key - Remove(key []byte) error - // Destroy removes the persistence medium stored data - Destroy() error - // DestroyClosed removes the already closed persistence medium stored data - DestroyClosed() error - RangeKeys(handler func(key []byte, val []byte) bool) - // IsInterfaceNil returns true if there is no value under the interface - IsInterfaceNil() bool -} +type Persister = types.Persister // Batcher allows to batch the data first then write the batch to the persister in one go type Batcher interface { @@ -202,6 +184,24 @@ type AdaptedSizedLRUCache interface { IsInterfaceNil() bool } +// ShardIDProvider defines what a component which is able to provide persister id per key should do +type ShardIDProvider interface { + ComputeId(key []byte) uint32 + NumberOfShards() uint32 + GetShardIDs() []uint32 + IsInterfaceNil() bool +} + +// PersisterCreator defines the behavour of a component which is able to create a persister +type PersisterCreator = types.PersisterCreator + +// DBConfigHandler defines the behaviour of a component that will handle db config +type DBConfigHandler interface { + GetDBConfig(path string) (*config.DBConfig, error) + SaveDBConfigToFilePath(path string, dbConfig *config.DBConfig) error + IsInterfaceNil() bool +} + // ManagedPeersHolder defines the operations of an entity that holds managed identities for a node type ManagedPeersHolder interface { IsMultiKeyMode() bool diff --git a/storage/latestData/latestDataProvider.go b/storage/latestData/latestDataProvider.go index d372f81b43c..df6ea7e2418 100644 --- a/storage/latestData/latestDataProvider.go +++ b/storage/latestData/latestDataProvider.go @@ -132,7 +132,12 @@ func (ldp *latestDataProvider) getEpochDirs() ([]string, error) { } func (ldp *latestDataProvider) getLastEpochAndRoundFromStorage(parentDir string, lastEpoch uint32) (storage.LatestDataFromStorage, error) { - persisterFactory := factory.NewPersisterFactory(ldp.generalConfig.BootstrapStorage.DB) + dbConfigHandler := factory.NewDBConfigHandler(ldp.generalConfig.BootstrapStorage.DB) + persisterFactory, err := factory.NewPersisterFactory(dbConfigHandler) + if err != nil { + return storage.LatestDataFromStorage{}, err + } + pathWithoutShard := filepath.Join( parentDir, fmt.Sprintf("%s_%d", ldp.defaultEpochString, lastEpoch), diff --git a/storage/pruning/fullHistoryPruningStorer_test.go b/storage/pruning/fullHistoryPruningStorer_test.go index 62c2d0c3b8c..255512ce958 100644 --- a/storage/pruning/fullHistoryPruningStorer_test.go +++ b/storage/pruning/fullHistoryPruningStorer_test.go @@ -294,14 +294,19 @@ func TestFullHistoryPruningStorer_ConcurrentOperations(t *testing.T) { fmt.Println(testDir) args := getDefaultArgs() - args.PersisterFactory = factory.NewPersisterFactory(config.DBConfig{ - FilePath: filepath.Join(testDir, dbName), - Type: "LvlDBSerial", - MaxBatchSize: 100, - MaxOpenFiles: 10, - BatchDelaySeconds: 2, - }) - var err error + dbConfigHandler := factory.NewDBConfigHandler( + config.DBConfig{ + FilePath: filepath.Join(testDir, dbName), + Type: "LvlDBSerial", + MaxBatchSize: 100, + MaxOpenFiles: 10, + BatchDelaySeconds: 2, + }, + ) + persisterFactory, err := factory.NewPersisterFactory(dbConfigHandler) + require.Nil(t, err) + args.PersisterFactory = persisterFactory + args.PathManager, err = pathmanager.NewPathManager(testDir+"/epoch_[E]/shard_[S]/[I]", "shard_[S]/[I]", "db") require.NoError(t, err) fhArgs := pruning.FullHistoryStorerArgs{ diff --git a/storage/pruning/pruningStorer_test.go b/storage/pruning/pruningStorer_test.go index 113eaf6ab26..bd50e2b0681 100644 --- a/storage/pruning/pruningStorer_test.go +++ b/storage/pruning/pruningStorer_test.go @@ -1049,14 +1049,20 @@ func TestPruningStorer_ConcurrentOperations(t *testing.T) { fmt.Println(testDir) args := getDefaultArgs() - args.PersisterFactory = factory.NewPersisterFactory(config.DBConfig{ - FilePath: filepath.Join(testDir, dbName), - Type: "LvlDBSerial", - MaxBatchSize: 100, - MaxOpenFiles: 10, - BatchDelaySeconds: 2, - }) - var err error + + dbConfigHandler := factory.NewDBConfigHandler( + config.DBConfig{ + FilePath: filepath.Join(testDir, dbName), + Type: "LvlDBSerial", + MaxBatchSize: 100, + MaxOpenFiles: 10, + BatchDelaySeconds: 2, + }, + ) + persisterFactory, err := factory.NewPersisterFactory(dbConfigHandler) + require.Nil(t, err) + + args.PersisterFactory = persisterFactory args.PathManager, err = pathmanager.NewPathManager(testDir+"/epoch_[E]/shard_[S]/[I]", "shard_[S]/[I]", "db") require.NoError(t, err) diff --git a/storage/storageunit/constants.go b/storage/storageunit/constants.go index 91cf2f76eec..0e128af8123 100644 --- a/storage/storageunit/constants.go +++ b/storage/storageunit/constants.go @@ -9,8 +9,7 @@ const ( SizeLRUCache = storageUnit.SizeLRUCache ) -// LvlDB currently the only supported DBs -// More to be added +// DB types that are currently supported const ( // LvlDB represents a levelDB storage identifier LvlDB = storageUnit.LvlDB @@ -19,3 +18,8 @@ const ( // MemoryDB represents an in memory storage identifier MemoryDB = storageUnit.MemoryDB ) + +// Shard id provider types that are currently supported +const ( + BinarySplit = storageUnit.BinarySplit +) diff --git a/storage/storageunit/storageunit.go b/storage/storageunit/storageunit.go index ccb0eb6aa84..fc205e12a33 100644 --- a/storage/storageunit/storageunit.go +++ b/storage/storageunit/storageunit.go @@ -29,6 +29,9 @@ type CacheType = storageUnit.CacheType // DBType represents the type of the supported databases type DBType = storageUnit.DBType +// ShardIDProviderType represents the type of the supported shard id providers +type ShardIDProviderType = storageUnit.ShardIDProviderType + // NewStorageUnit is the constructor for the storage unit, creating a new storage unit // from the given cacher and persister. func NewStorageUnit(c storage.Cacher, p storage.Persister) (*Unit, error) { diff --git a/storage/storageunit/storageunit_test.go b/storage/storageunit/storageunit_test.go index 9b9b125fa7e..ff21f26e252 100644 --- a/storage/storageunit/storageunit_test.go +++ b/storage/storageunit/storageunit_test.go @@ -7,10 +7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/storage/mock" "github.com/multiversx/mx-chain-go/storage/storageunit" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-storage-go/common" - - "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/assert" ) @@ -155,7 +155,7 @@ func TestNewStorageCacherAdapter(t *testing.T) { cacher := &mock.AdaptedSizedLruCacheStub{} db := &mock.PersisterStub{} storedDataFactory := &storage.StoredDataFactoryStub{} - marshaller := &testscommon.MarshalizerStub{} + marshaller := &marshallerMock.MarshalizerStub{} t.Run("nil parameter should error", func(t *testing.T) { t.Parallel() diff --git a/testscommon/components/components.go b/testscommon/components/components.go index e713d3ed758..c979f7c2775 100644 --- a/testscommon/components/components.go +++ b/testscommon/components/components.go @@ -330,8 +330,8 @@ func GetStateFactoryArgs(coreComponents factory.CoreComponentsHolder) stateComp. trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] = storageManagerPeer triesHolder := state.NewDataTriesHolder() - trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) - triePeers, _ := trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), 5) + trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), 5) + triePeers, _ := trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), 5) triesHolder.Put([]byte(dataRetriever.UserAccountsUnit.String()), trieUsers) triesHolder.Put([]byte(dataRetriever.PeerAccountsUnit.String()), triePeers) diff --git a/testscommon/components/default.go b/testscommon/components/default.go index d90406199db..38df0c22211 100644 --- a/testscommon/components/default.go +++ b/testscommon/components/default.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/factory" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -27,9 +28,9 @@ import ( // GetDefaultCoreComponents - func GetDefaultCoreComponents() *mock.CoreComponentsMock { return &mock.CoreComponentsMock{ - IntMarsh: &testscommon.MarshalizerMock{}, - TxMarsh: &testscommon.MarshalizerMock{}, - VmMarsh: &testscommon.MarshalizerMock{}, + IntMarsh: &marshallerMock.MarshalizerMock{}, + TxMarsh: &marshallerMock.MarshalizerMock{}, + VmMarsh: &marshallerMock.MarshalizerMock{}, Hash: &testscommon.HasherStub{}, UInt64ByteSliceConv: testscommon.NewNonceHashConverterMock(), AddrPubKeyConv: testscommon.NewPubkeyConverterMock(32), diff --git a/testscommon/enableEpochsHandlerStub.go b/testscommon/enableEpochsHandlerMock/enableEpochsHandlerStub.go similarity index 99% rename from testscommon/enableEpochsHandlerStub.go rename to testscommon/enableEpochsHandlerMock/enableEpochsHandlerStub.go index 4fb26b750f8..038032a3fe1 100644 --- a/testscommon/enableEpochsHandlerStub.go +++ b/testscommon/enableEpochsHandlerMock/enableEpochsHandlerStub.go @@ -1,4 +1,4 @@ -package testscommon +package enableEpochsHandlerMock import "sync" @@ -122,6 +122,7 @@ type EnableEpochsHandlerStub struct { IsMultiClaimOnDelegationEnabledField bool IsChangeUsernameEnabledField bool IsConsistentTokensValuesLengthCheckEnabledField bool + IsAutoBalanceDataTriesEnabledField bool } // ResetPenalizedTooMuchGasFlag - @@ -1059,6 +1060,14 @@ func (stub *EnableEpochsHandlerStub) IsConsistentTokensValuesLengthCheckEnabled( return stub.IsConsistentTokensValuesLengthCheckEnabledField } +// IsAutoBalanceDataTriesEnabled - +func (stub *EnableEpochsHandlerStub) IsAutoBalanceDataTriesEnabled() bool { + stub.RLock() + defer stub.RUnlock() + + return stub.IsAutoBalanceDataTriesEnabledField +} + // IsInterfaceNil - func (stub *EnableEpochsHandlerStub) IsInterfaceNil() bool { return stub == nil diff --git a/testscommon/integrationtests/accountsFactory.go b/testscommon/integrationtests/accountsFactory.go index f6528564e94..a8912560f51 100644 --- a/testscommon/integrationtests/accountsFactory.go +++ b/testscommon/integrationtests/accountsFactory.go @@ -7,11 +7,12 @@ import ( // TestAccountFactory - type TestAccountFactory struct { + args state.ArgsAccountCreation } // CreateAccount - func (factory *TestAccountFactory) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return state.NewUserAccount(address) + return state.NewUserAccount(address, factory.args) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/integrationtests/factory.go b/testscommon/integrationtests/factory.go index 3a1302d43b5..350cc175177 100644 --- a/testscommon/integrationtests/factory.go +++ b/testscommon/integrationtests/factory.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" testcommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" @@ -58,7 +59,12 @@ func CreateStorer(parentDir string) storage.Storer { MaxBatchSize: 45000, MaxOpenFiles: 10, } - persisterFactory := factory.NewPersisterFactory(dbConfig) + dbConfigHandler := factory.NewDBConfigHandler(dbConfig) + persisterFactory, err := factory.NewPersisterFactory(dbConfigHandler) + if err != nil { + return nil + } + triePersister, err := persisterFactory.Create(parentDir) if err != nil { return nil @@ -74,11 +80,11 @@ func CreateStorer(parentDir string) storage.Storer { // CreateInMemoryShardAccountsDB - func CreateInMemoryShardAccountsDB() *state.AccountsDB { - return CreateAccountsDB(CreateMemUnit()) + return CreateAccountsDB(CreateMemUnit(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) } // CreateAccountsDB - -func CreateAccountsDB(db storage.Storer) *state.AccountsDB { +func CreateAccountsDB(db storage.Storer, enableEpochs common.EnableEpochsHandler) *state.AccountsDB { ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, @@ -92,14 +98,20 @@ func CreateAccountsDB(db storage.Storer) *state.AccountsDB { trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, MaxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, enableEpochs, MaxTrieLevelInMemory) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) argsAccountsDB := state.ArgsAccountsDB{ - Trie: tr, - Hasher: TestHasher, - Marshaller: TestMarshalizer, - AccountFactory: &TestAccountFactory{}, + Trie: tr, + Hasher: TestHasher, + Marshaller: TestMarshalizer, + AccountFactory: &TestAccountFactory{ + args: state.ArgsAccountCreation{ + Hasher: TestHasher, + Marshaller: TestMarshalizer, + EnableEpochsHandler: enableEpochs, + }, + }, StoragePruningManager: spm, ProcessingMode: common.Normal, ProcessStatusHandler: &testscommon.ProcessStatusHandlerStub{}, diff --git a/testscommon/marshalizerMock.go b/testscommon/marshallerMock/marshalizerMock.go similarity index 97% rename from testscommon/marshalizerMock.go rename to testscommon/marshallerMock/marshalizerMock.go index 57684a4b79a..a92cc150528 100644 --- a/testscommon/marshalizerMock.go +++ b/testscommon/marshallerMock/marshalizerMock.go @@ -1,4 +1,4 @@ -package testscommon +package marshallerMock import ( "encoding/json" diff --git a/testscommon/marshalizerStub.go b/testscommon/marshallerMock/marshalizerStub.go similarity index 96% rename from testscommon/marshalizerStub.go rename to testscommon/marshallerMock/marshalizerStub.go index 8281a41278a..7f35509abcb 100644 --- a/testscommon/marshalizerStub.go +++ b/testscommon/marshallerMock/marshalizerStub.go @@ -1,4 +1,4 @@ -package testscommon +package marshallerMock // MarshalizerStub - type MarshalizerStub struct { diff --git a/testscommon/state/accountAdapterStub.go b/testscommon/state/accountAdapterStub.go index 8e9ec352a36..433722f7e21 100644 --- a/testscommon/state/accountAdapterStub.go +++ b/testscommon/state/accountAdapterStub.go @@ -1,6 +1,7 @@ package state import ( + "context" "math/big" "github.com/multiversx/mx-chain-go/common" @@ -34,6 +35,7 @@ type StateUserAccountHandlerStub struct { SetUserNameCalled func(userName []byte) GetUserNameCalled func() []byte IsGuardedCalled func() bool + GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context) error } // AddressBytes - @@ -236,6 +238,15 @@ func (aas *StateUserAccountHandlerStub) IsGuarded() bool { return false } +// GetAllLeaves - +func (aas *StateUserAccountHandlerStub) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context) error { + if aas.GetAllLeavesCalled != nil { + return aas.GetAllLeavesCalled(leavesChannels, ctx) + } + + return nil +} + // IsInterfaceNil - func (aas *StateUserAccountHandlerStub) IsInterfaceNil() bool { return aas == nil diff --git a/testscommon/state/accountFactoryStub.go b/testscommon/state/accountFactoryStub.go index e41a84cc709..c5b9f482a03 100644 --- a/testscommon/state/accountFactoryStub.go +++ b/testscommon/state/accountFactoryStub.go @@ -1,17 +1,23 @@ package state -import "github.com/multiversx/mx-chain-vm-common-go" +import ( + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-vm-common-go" +) // TODO: move all the mocks from the mock package to testscommon // AccountsFactoryStub - type AccountsFactoryStub struct { - CreateAccountCalled func(address []byte) (vmcommon.AccountHandler, error) + CreateAccountCalled func(address []byte, hasher hashing.Hasher, marshaller marshal.Marshalizer) (vmcommon.AccountHandler, error) } // CreateAccount - func (afs *AccountsFactoryStub) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return afs.CreateAccountCalled(address) + return afs.CreateAccountCalled(address, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/state/accountWrapperMock.go b/testscommon/state/accountWrapperMock.go index 2e717f410ad..eab738712bd 100644 --- a/testscommon/state/accountWrapperMock.go +++ b/testscommon/state/accountWrapperMock.go @@ -2,10 +2,16 @@ package state import ( + "context" + "fmt" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -20,9 +26,9 @@ type AccountWrapMock struct { CodeMetadata []byte RootHash []byte address []byte - trackableDataTrie state.DataTrieTracker Balance *big.Int guarded bool + trackableDataTrie state.DataTrieTracker SetNonceWithJournalCalled func(nonce uint64) error `json:"-"` SetCodeHashWithJournalCalled func(codeHash []byte) error `json:"-"` @@ -30,11 +36,22 @@ type AccountWrapMock struct { AccountDataHandlerCalled func() vmcommon.AccountDataHandler `json:"-"` } +var errInsufficientBalance = fmt.Errorf("insufficient balance") + // NewAccountWrapMock - func NewAccountWrapMock(adr []byte) *AccountWrapMock { + tdt, _ := state.NewTrackableDataTrie( + []byte("identifier"), + nil, + &hashingMocks.HasherMock{}, + &marshallerMock.MarshalizerMock{}, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ) + return &AccountWrapMock{ address: adr, - trackableDataTrie: state.NewTrackableDataTrie([]byte("identifier"), nil), + trackableDataTrie: tdt, + Balance: big.NewInt(0), } } @@ -53,12 +70,22 @@ func (awm *AccountWrapMock) GetUserName() []byte { } // AddToBalance - -func (awm *AccountWrapMock) AddToBalance(_ *big.Int) error { +func (awm *AccountWrapMock) AddToBalance(val *big.Int) error { + newBalance := big.NewInt(0).Add(awm.Balance, val) + if newBalance.Cmp(big.NewInt(0)) < 0 { + return errInsufficientBalance + } + awm.Balance = newBalance return nil } // SubFromBalance - -func (awm *AccountWrapMock) SubFromBalance(_ *big.Int) error { +func (awm *AccountWrapMock) SubFromBalance(val *big.Int) error { + newBalance := big.NewInt(0).Sub(awm.Balance, val) + if newBalance.Cmp(big.NewInt(0)) < 0 { + return errInsufficientBalance + } + awm.Balance = newBalance return nil } @@ -163,7 +190,7 @@ func (awm *AccountWrapMock) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) (map[string][]byte, error) { +func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { return awm.trackableDataTrie.SaveDirtyData(trie) } @@ -194,3 +221,8 @@ func (awm *AccountWrapMock) GetNonce() uint64 { func (awm *AccountWrapMock) IsGuarded() bool { return awm.guarded } + +// GetAllLeaves - +func (awm *AccountWrapMock) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context) error { + return nil +} diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index 5500fbb6e0d..b7bb5da7b35 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -30,7 +30,7 @@ type AccountsStub struct { SnapshotStateCalled func(rootHash []byte) SetStateCheckpointCalled func(rootHash []byte) IsPruningEnabledCalled func() bool - GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error + GetAllLeavesCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error RecreateAllTriesCalled func(rootHash []byte) (map[string]common.Trie, error) GetCodeCalled func([]byte) []byte GetTrieCalled func([]byte) (common.Trie, error) @@ -102,9 +102,9 @@ func (as *AccountsStub) SaveAccount(account vmcommon.AccountHandler) error { } // GetAllLeaves - -func (as *AccountsStub) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { +func (as *AccountsStub) GetAllLeaves(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, trieLeafParser common.TrieLeafParser) error { if as.GetAllLeavesCalled != nil { - return as.GetAllLeavesCalled(leavesChannels, ctx, rootHash) + return as.GetAllLeavesCalled(leavesChannels, ctx, rootHash, trieLeafParser) } return nil } diff --git a/testscommon/state/userAccountStub.go b/testscommon/state/userAccountStub.go index 538316f0acd..b2d498064b7 100644 --- a/testscommon/state/userAccountStub.go +++ b/testscommon/state/userAccountStub.go @@ -2,8 +2,10 @@ package state import ( + "context" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -168,7 +170,7 @@ func (u *UserAccountStub) IsGuarded() bool { } // SaveDirtyData - -func (u *UserAccountStub) SaveDirtyData(_ common.Trie) (map[string][]byte, error) { +func (u *UserAccountStub) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { return nil, nil } @@ -184,3 +186,8 @@ func (u *UserAccountStub) AccountDataHandler() vmcommon.AccountDataHandler { } return nil } + +// GetAllLeaves - +func (u *UserAccountStub) GetAllLeaves(_ *common.TrieIteratorChannels, _ context.Context) error { + return nil +} diff --git a/testscommon/trie/dataTrieMigratorStub.go b/testscommon/trie/dataTrieMigratorStub.go new file mode 100644 index 00000000000..57bab03dbc8 --- /dev/null +++ b/testscommon/trie/dataTrieMigratorStub.go @@ -0,0 +1,44 @@ +package trie + +import ( + "github.com/multiversx/mx-chain-core-go/core" +) + +// DataTrieMigratorStub - +type DataTrieMigratorStub struct { + ConsumeStorageLoadGasCalled func() bool + AddLeafToMigrationQueueCalled func(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) + GetLeavesToBeMigratedCalled func() []core.TrieData +} + +// ConsumeStorageLoadGas - +func (d *DataTrieMigratorStub) ConsumeStorageLoadGas() bool { + if d.ConsumeStorageLoadGasCalled != nil { + return d.ConsumeStorageLoadGasCalled() + } + + return true +} + +// AddLeafToMigrationQueue - +func (d *DataTrieMigratorStub) AddLeafToMigrationQueue(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) { + if d.AddLeafToMigrationQueueCalled != nil { + return d.AddLeafToMigrationQueueCalled(leafData, newLeafVersion) + } + + return true, nil +} + +// GetLeavesToBeMigrated - +func (d *DataTrieMigratorStub) GetLeavesToBeMigrated() []core.TrieData { + if d.GetLeavesToBeMigratedCalled != nil { + return d.GetLeavesToBeMigratedCalled() + } + + return nil +} + +// IsInterfaceNil - +func (d *DataTrieMigratorStub) IsInterfaceNil() bool { + return d == nil +} diff --git a/testscommon/trie/dataTrieTrackerStub.go b/testscommon/trie/dataTrieTrackerStub.go index a49fe811a4b..838c1611264 100644 --- a/testscommon/trie/dataTrieTrackerStub.go +++ b/testscommon/trie/dataTrieTrackerStub.go @@ -1,16 +1,20 @@ package trie import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // DataTrieTrackerStub - type DataTrieTrackerStub struct { - RetrieveValueCalled func(key []byte) ([]byte, uint32, error) - SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) - DataTrieCalled func() common.Trie - SaveDirtyDataCalled func(trie common.Trie) (map[string][]byte, error) + RetrieveValueCalled func(key []byte) ([]byte, uint32, error) + SaveKeyValueCalled func(key []byte, value []byte) error + SetDataTrieCalled func(tr common.Trie) + DataTrieCalled func() common.Trie + SaveDirtyDataCalled func(trie common.Trie) ([]core.TrieData, error) + SaveTrieDataCalled func(trieData core.TrieData) error + MigrateDataTrieLeavesCalled func(args vmcommon.ArgsMigrateDataTrieLeaves) error } // RetrieveValue - @@ -48,12 +52,21 @@ func (dtts *DataTrieTrackerStub) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) (map[string][]byte, error) { +func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { if dtts.SaveDirtyDataCalled != nil { return dtts.SaveDirtyDataCalled(mainTrie) } - return map[string][]byte{}, nil + return make([]core.TrieData, 0), nil +} + +// MigrateDataTrieLeaves - +func (dtts *DataTrieTrackerStub) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error { + if dtts.MigrateDataTrieLeavesCalled != nil { + return dtts.MigrateDataTrieLeavesCalled(args) + } + + return nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/trie/statisticsMock.go b/testscommon/trie/statisticsMock.go index 11fbfe654da..6c6272c4aba 100644 --- a/testscommon/trie/statisticsMock.go +++ b/testscommon/trie/statisticsMock.go @@ -1,7 +1,7 @@ package trie import ( - "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/common" ) // MockStatistics - @@ -25,5 +25,10 @@ func (m *MockStatistics) WaitForSnapshotsToFinish() { } // AddTrieStats - -func (m *MockStatistics) AddTrieStats(_ *statistics.TrieStatsDTO) { +func (m *MockStatistics) AddTrieStats(_ common.TrieStatisticsHandler, _ common.TrieType) { +} + +// IsInterfaceNil returns true if there is no value under the interface +func (m *MockStatistics) IsInterfaceNil() bool { + return m == nil } diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index b6707e2752e..81c90867e92 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -4,31 +4,36 @@ import ( "context" "errors" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var errNotImplemented = errors.New("not implemented") // TrieStub - type TrieStub struct { - GetCalled func(key []byte) ([]byte, uint32, error) - UpdateCalled func(key, value []byte) error - DeleteCalled func(key []byte) error - RootCalled func() ([]byte, error) - CommitCalled func() error - RecreateCalled func(root []byte) (common.Trie, error) - RecreateFromEpochCalled func(options common.RootHashHolder) (common.Trie, error) - GetObsoleteHashesCalled func() [][]byte - AppendToOldHashesCalled func([][]byte) - GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error) - GetAllHashesCalled func() ([][]byte, error) - GetAllLeavesOnChannelCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error - GetProofCalled func(key []byte) ([][]byte, []byte, error) - VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) - GetStorageManagerCalled func() common.StorageManager - GetSerializedNodeCalled func(bytes []byte) ([]byte, error) - GetOldRootCalled func() []byte - CloseCalled func() error + GetCalled func(key []byte) ([]byte, uint32, error) + UpdateCalled func(key, value []byte) error + UpdateWithVersionCalled func(key, value []byte, version core.TrieNodeVersion) error + DeleteCalled func(key []byte) error + RootCalled func() ([]byte, error) + CommitCalled func() error + RecreateCalled func(root []byte) (common.Trie, error) + RecreateFromEpochCalled func(options common.RootHashHolder) (common.Trie, error) + GetObsoleteHashesCalled func() [][]byte + AppendToOldHashesCalled func([][]byte) + GetSerializedNodesCalled func([]byte, uint64) ([][]byte, uint64, error) + GetAllHashesCalled func() ([][]byte, error) + GetAllLeavesOnChannelCalled func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, trieLeafParser common.TrieLeafParser) error + GetProofCalled func(key []byte) ([][]byte, []byte, error) + VerifyProofCalled func(rootHash []byte, key []byte, proof [][]byte) (bool, error) + GetStorageManagerCalled func() common.StorageManager + GetSerializedNodeCalled func(bytes []byte) ([]byte, error) + GetOldRootCalled func() []byte + CloseCalled func() error + CollectLeavesForMigrationCalled func(args vmcommon.ArgsMigrateDataTrieLeaves) error + IsMigratedToLatestVersionCalled func() (bool, error) } // GetStorageManager - @@ -59,9 +64,9 @@ func (ts *TrieStub) VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bo } // GetAllLeavesOnChannel - -func (ts *TrieStub) GetAllLeavesOnChannel(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { +func (ts *TrieStub) GetAllLeavesOnChannel(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, trieLeafParser common.TrieLeafParser) error { if ts.GetAllLeavesOnChannelCalled != nil { - return ts.GetAllLeavesOnChannelCalled(leavesChannels, ctx, rootHash, keyBuilder) + return ts.GetAllLeavesOnChannelCalled(leavesChannels, ctx, rootHash, keyBuilder, trieLeafParser) } return nil @@ -85,6 +90,24 @@ func (ts *TrieStub) Update(key, value []byte) error { return errNotImplemented } +// UpdateWithVersion - +func (ts *TrieStub) UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error { + if ts.UpdateWithVersionCalled != nil { + return ts.UpdateWithVersionCalled(key, value, version) + } + + return errNotImplemented +} + +// CollectLeavesForMigration - +func (ts *TrieStub) CollectLeavesForMigration(args vmcommon.ArgsMigrateDataTrieLeaves) error { + if ts.CollectLeavesForMigrationCalled != nil { + return ts.CollectLeavesForMigrationCalled(args) + } + + return errNotImplemented +} + // Delete - func (ts *TrieStub) Delete(key []byte) error { if ts.DeleteCalled != nil { @@ -193,6 +216,15 @@ func (ts *TrieStub) GetOldRoot() []byte { return nil } +// IsMigratedToLatestVersion - +func (ts *TrieStub) IsMigratedToLatestVersion() (bool, error) { + if ts.IsMigratedToLatestVersionCalled != nil { + return ts.IsMigratedToLatestVersionCalled() + } + + return false, nil +} + // Close - func (ts *TrieStub) Close() error { if ts.CloseCalled != nil { diff --git a/trie/branchNode.go b/trie/branchNode.go index 66fa48e8d9e..01f1268d339 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&branchNode{}) @@ -41,19 +42,18 @@ func newBranchNode(marshalizer marshal.Marshalizer, hasher hashing.Hasher) (*bra }, nil } -func emptyDirtyBranchNode() *branchNode { - var children [nrOfChildren]node - encChildren := make([][]byte, nrOfChildren) +func (bn *branchNode) setVersionForChild(version core.TrieNodeVersion, childPos byte) { + sliceNotInitialized := len(bn.ChildrenVersion) == 0 - return &branchNode{ - CollapsedBn: CollapsedBn{ - EncodedChildren: encChildren, - }, - children: children, - baseNode: &baseNode{ - dirty: true, - }, + if version == core.NotSpecified && sliceNotInitialized { + return + } + + if sliceNotInitialized { + bn.ChildrenVersion = make([]byte, nrOfChildren) } + + bn.ChildrenVersion[int(childPos)] = byte(version) } func (bn *branchNode) getHash() []byte { @@ -489,68 +489,79 @@ func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node return bn.children[childPos], key, nil } -func (bn *branchNode) insert(n *leafNode, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (bn *branchNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { return nil, emptyHashes, fmt.Errorf("insert error %w", err) } - insertedKey := n.Key - if len(insertedKey) == 0 { + if len(newData.Key) == 0 { return nil, emptyHashes, ErrValueTooShort } - childPos := insertedKey[firstByte] + childPos := newData.Key[firstByte] if childPosOutOfRange(childPos) { return nil, emptyHashes, ErrChildPosOutOfRange } - n.Key = insertedKey[1:] + newData.Key = newData.Key[1:] err = resolveIfCollapsed(bn, childPos, db) if err != nil { return nil, emptyHashes, err } if bn.children[childPos] == nil { - return bn.insertOnNilChild(n, childPos) + return bn.insertOnNilChild(newData, childPos) } - return bn.insertOnExistingChild(n, childPos, db) + return bn.insertOnExistingChild(newData, childPos, db) } -func (bn *branchNode) insertOnNilChild(n *leafNode, childPos byte) (node, [][]byte, error) { - newLn, err := newLeafNode(n.Key, n.Value, bn.marsh, bn.hasher) +func (bn *branchNode) insertOnNilChild(newData core.TrieData, childPos byte) (node, [][]byte, error) { + newLn, err := newLeafNode(newData, bn.marsh, bn.hasher) if err != nil { return nil, [][]byte{}, err } modifiedHashes := make([][]byte, 0) - modifiedHashes = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newLn) + modifiedHashes, err = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newLn) + if err != nil { + return nil, [][]byte{}, err + } return bn, modifiedHashes, nil } -func (bn *branchNode) insertOnExistingChild(n *leafNode, childPos byte, db common.TrieStorageInteractor) (node, [][]byte, error) { - newNode, modifiedHashes, err := bn.children[childPos].insert(n, db) +func (bn *branchNode) insertOnExistingChild(newData core.TrieData, childPos byte, db common.TrieStorageInteractor) (node, [][]byte, error) { + newNode, modifiedHashes, err := bn.children[childPos].insert(newData, db) if check.IfNil(newNode) || err != nil { return nil, [][]byte{}, err } - modifiedHashes = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newNode) + modifiedHashes, err = bn.modifyNodeAfterInsert(modifiedHashes, childPos, newNode) + if err != nil { + return nil, [][]byte{}, err + } return bn, modifiedHashes, nil } -func (bn *branchNode) modifyNodeAfterInsert(modifiedHashes [][]byte, childPos byte, newNode node) [][]byte { +func (bn *branchNode) modifyNodeAfterInsert(modifiedHashes [][]byte, childPos byte, newNode node) ([][]byte, error) { if !bn.dirty { modifiedHashes = append(modifiedHashes, bn.hash) } + childVersion, err := newNode.getVersion() + if err != nil { + return nil, err + } + bn.children[childPos] = newNode + bn.setVersionForChild(childVersion, childPos) bn.dirty = true bn.hash = nil - return modifiedHashes + return modifiedHashes, nil } func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { @@ -585,10 +596,9 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, oldHashes = append(oldHashes, bn.hash) } - bn.hash = nil - bn.children[childPos] = newNode - if newNode == nil { - bn.EncodedChildren[childPos] = nil + err = bn.setNewChild(childPos, newNode) + if err != nil { + return false, nil, emptyHashes, err } numChildren, pos := getChildPosition(bn) @@ -622,6 +632,25 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, return true, bn, oldHashes, nil } +func (bn *branchNode) setNewChild(childPos byte, newNode node) error { + bn.hash = nil + bn.children[childPos] = newNode + if check.IfNil(newNode) { + bn.setVersionForChild(0, childPos) + bn.EncodedChildren[childPos] = nil + + return nil + } + + childVersion, err := newNode.getVersion() + if err != nil { + return err + } + bn.setVersionForChild(childVersion, childPos) + + return nil +} + func (bn *branchNode) reduceNode(pos int) (node, bool, error) { newEn, err := newExtensionNode([]byte{byte(pos)}, bn, bn.marsh, bn.hasher) if err != nil { @@ -781,6 +810,7 @@ func (bn *branchNode) loadChildren(getNode func([]byte) (node, error)) ([][]byte func (bn *branchNode) getAllLeavesOnChannel( leavesChannel chan core.KeyValueHolder, keyBuilder common.KeyBuilder, + trieLeafParser common.TrieLeafParser, db common.TrieStorageInteractor, marshalizer marshal.Marshalizer, chanClose chan struct{}, @@ -811,7 +841,7 @@ func (bn *branchNode) getAllLeavesOnChannel( clonedKeyBuilder := keyBuilder.Clone() clonedKeyBuilder.BuildKey([]byte{byte(i)}) - err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, db, marshalizer, chanClose, ctx) + err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, trieLeafParser, db, marshalizer, chanClose, ctx) if err != nil { return err } @@ -915,6 +945,91 @@ func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, depthLevel i return nil } +func (bn *branchNode) getVersion() (core.TrieNodeVersion, error) { + if len(bn.ChildrenVersion) == 0 { + return core.NotSpecified, nil + } + + index := 0 + var nodeVersion byte + for i := range bn.children { + index++ + if bn.children[i] == nil && len(bn.EncodedChildren[i]) == 0 { + continue + } + + nodeVersion = bn.ChildrenVersion[i] + break + } + + for i := index; i < len(bn.children); i++ { + if bn.children[i] == nil && len(bn.EncodedChildren[i]) == 0 { + continue + } + + if bn.ChildrenVersion[i] != nodeVersion { + return core.NotSpecified, nil + } + } + + return core.TrieNodeVersion(nodeVersion), nil +} + +func (bn *branchNode) getVersionForChild(childIndex byte) core.TrieNodeVersion { + if len(bn.ChildrenVersion) == 0 { + return core.NotSpecified + } + + return core.TrieNodeVersion(bn.ChildrenVersion[childIndex]) +} + +func (bn *branchNode) collectLeavesForMigration( + migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + db common.TrieStorageInteractor, + keyBuilder common.KeyBuilder, +) (bool, error) { + shouldContinue := migrationArgs.TrieMigrator.ConsumeStorageLoadGas() + if !shouldContinue { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(bn, migrationArgs) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + for i := range bn.children { + if bn.children[i] == nil && len(bn.EncodedChildren[i]) == 0 { + continue + } + + if bn.getVersionForChild(byte(i)) != migrationArgs.OldVersion { + continue + } + + err = resolveIfCollapsed(bn, byte(i), db) + if err != nil { + return false, err + } + + clonedKeyBuilder := keyBuilder.Clone() + clonedKeyBuilder.BuildKey([]byte{byte(i)}) + shouldContinueMigrating, err := bn.children[i].collectLeavesForMigration(migrationArgs, db, clonedKeyBuilder) + if err != nil { + return false, err + } + + if !shouldContinueMigrating { + return false, nil + } + } + + return true, nil +} + // IsInterfaceNil returns true if there is no value under the interface func (bn *branchNode) IsInterfaceNil() bool { return bn == nil diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index e3f1118c61a..e2959add025 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -13,7 +13,9 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) @@ -24,13 +26,21 @@ func getTestMarshalizerAndHasher() (marshal.Marshalizer, hashing.Hasher) { return marsh, hash } +func getTrieDataWithDefaultVersion(key string, val string) core.TrieData { + return core.TrieData{ + Key: []byte(key), + Value: []byte(val), + Version: core.NotSpecified, + } +} + func getBnAndCollapsedBn(marshalizer marshal.Marshalizer, hasher hashing.Hasher) (*branchNode, *branchNode) { var children [nrOfChildren]node EncodedChildren := make([][]byte, nrOfChildren) - children[2], _ = newLeafNode([]byte("dog"), []byte("dog"), marshalizer, hasher) - children[6], _ = newLeafNode([]byte("doe"), []byte("doe"), marshalizer, hasher) - children[13], _ = newLeafNode([]byte("doge"), []byte("doge"), marshalizer, hasher) + children[2], _ = newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), marshalizer, hasher) + children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), marshalizer, hasher) + children[13], _ = newLeafNode(getTrieDataWithDefaultVersion("doge", "doge"), marshalizer, hasher) bn, _ := newBranchNode(marshalizer, hasher) bn.children = children @@ -43,6 +53,23 @@ func getBnAndCollapsedBn(marshalizer marshal.Marshalizer, hasher hashing.Hasher) return bn, collapsedBn } +func emptyDirtyBranchNode() *branchNode { + var children [nrOfChildren]node + encChildren := make([][]byte, nrOfChildren) + childrenVersion := make([]byte, nrOfChildren) + + return &branchNode{ + CollapsedBn: CollapsedBn{ + EncodedChildren: encChildren, + ChildrenVersion: childrenVersion, + }, + children: children, + baseNode: &baseNode{ + dirty: true, + }, + } +} + func newEmptyTrie() (*patriciaMerkleTrie, *trieStorageManager) { args := GetDefaultTrieStorageManagerParameters() trieStorage, _ := NewTrieStorageManager(args) @@ -54,6 +81,7 @@ func newEmptyTrie() (*patriciaMerkleTrie, *trieStorageManager) { oldRoot: make([]byte, 0), maxTrieLevelInMemory: 5, chanClose: make(chan struct{}), + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return tr, trieStorage @@ -169,8 +197,8 @@ func TestBranchNode_setRootHash(t *testing.T) { trieStorage2, _ := NewTrieStorageManager(GetDefaultTrieStorageManagerParameters()) maxTrieLevelInMemory := uint(5) - tr1, _ := NewTrie(trieStorage1, marsh, hsh, maxTrieLevelInMemory) - tr2, _ := NewTrie(trieStorage2, marsh, hsh, maxTrieLevelInMemory) + tr1, _ := NewTrie(trieStorage1, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr2, _ := NewTrie(trieStorage2, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) maxIterations := 10000 for i := 0; i < maxIterations; i++ { @@ -399,7 +427,7 @@ func TestBranchNode_resolveCollapsed(t *testing.T) { _ = bn.setHash() _ = bn.commitDirty(0, 5, db, db) - resolved, _ := newLeafNode([]byte("dog"), []byte("dog"), bn.marsh, bn.hasher) + resolved, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) resolved.dirty = false resolved.hash = bn.EncodedChildren[childPos] @@ -443,7 +471,7 @@ func TestBranchNode_isCollapsed(t *testing.T) { assert.True(t, collapsedBn.isCollapsed()) assert.False(t, bn.isCollapsed()) - collapsedBn.children[2], _ = newLeafNode([]byte("dog"), []byte("dog"), bn.marsh, bn.hasher) + collapsedBn.children[2], _ = newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) assert.False(t, collapsedBn.isCollapsed()) } @@ -544,7 +572,7 @@ func TestBranchNode_getNext(t *testing.T) { t.Parallel() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - nextNode, _ := newLeafNode([]byte("dog"), []byte("dog"), bn.marsh, bn.hasher) + nextNode, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) @@ -587,14 +615,13 @@ func TestBranchNode_insert(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nodeKey := []byte{0, 2, 3} - n, _ := newLeafNode(nodeKey, []byte("dogs"), bn.marsh, bn.hasher) - newBn, _, err := bn.insert(n, nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) assert.NotNil(t, newBn) assert.Nil(t, err) nodeKeyRemainder := nodeKey[1:] - bn.children[0], _ = newLeafNode(nodeKeyRemainder, []byte("dogs"), bn.marsh, bn.hasher) + bn.children[0], _ = newLeafNode(getTrieDataWithDefaultVersion(string(nodeKeyRemainder), "dogs"), bn.marsh, bn.hasher) assert.Equal(t, bn, newBn) } @@ -602,9 +629,8 @@ func TestBranchNode_insertEmptyKey(t *testing.T) { t.Parallel() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - n, _ := newLeafNode([]byte{}, []byte("dogs"), bn.marsh, bn.hasher) - newBn, _, err := bn.insert(n, nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("", "dogs"), nil) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newBn) } @@ -613,9 +639,8 @@ func TestBranchNode_insertChildPosOutOfRange(t *testing.T) { t.Parallel() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - n, _ := newLeafNode([]byte("dog"), []byte("dogs"), bn.marsh, bn.hasher) - newBn, _, err := bn.insert(n, nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) assert.Equal(t, ErrChildPosOutOfRange, err) assert.Nil(t, newBn) } @@ -627,12 +652,11 @@ func TestBranchNode_insertCollapsedNode(t *testing.T) { bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - n, _ := newLeafNode(key, []byte("dogs"), bn.marsh, bn.hasher) _ = bn.setHash() _ = bn.commitDirty(0, 5, db, db) - newBn, _, err := collapsedBn.insert(n, db) + newBn, _, err := collapsedBn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) assert.NotNil(t, newBn) assert.Nil(t, err) @@ -647,7 +671,6 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - n, _ := newLeafNode(key, []byte("dogs"), bn.marsh, bn.hasher) _ = bn.commitDirty(0, 5, db, db) bnHash := bn.getHash() @@ -655,7 +678,7 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { lnHash := ln.getHash() expectedHashes := [][]byte{lnHash, bnHash} - newNode, oldHashes, err := bn.insert(n, db) + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -668,13 +691,12 @@ func TestBranchNode_insertInStoredBnOnNilPos(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) - n, _ := newLeafNode(key, []byte("dogs"), bn.marsh, bn.hasher) _ = bn.commitDirty(0, 5, db, db) bnHash := bn.getHash() expectedHashes := [][]byte{bnHash} - newNode, oldHashes, err := bn.insert(n, db) + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -686,9 +708,8 @@ func TestBranchNode_insertInDirtyBnOnNilPos(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) - n, _ := newLeafNode(key, []byte("dogs"), bn.marsh, bn.hasher) - newNode, oldHashes, err := bn.insert(n, nil) + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -700,9 +721,8 @@ func TestBranchNode_insertInDirtyBnOnExistingPos(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - n, _ := newLeafNode(key, []byte("dogs"), bn.marsh, bn.hasher) - newNode, oldHashes, err := bn.insert(n, nil) + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -713,7 +733,7 @@ func TestBranchNode_insertInNilNode(t *testing.T) { var bn *branchNode - newBn, _, err := bn.insert(&leafNode{}, nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("key", "dogs"), nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, newBn) } @@ -723,8 +743,8 @@ func TestBranchNode_delete(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) var children [nrOfChildren]node - children[6], _ = newLeafNode([]byte("doe"), []byte("doe"), bn.marsh, bn.hasher) - children[13], _ = newLeafNode([]byte("doge"), []byte("doge"), bn.marsh, bn.hasher) + children[6], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), bn.marsh, bn.hasher) + children[13], _ = newLeafNode(getTrieDataWithDefaultVersion("doge", "doge"), bn.marsh, bn.hasher) expectedBn, _ := newBranchNode(bn.marsh, bn.hasher) expectedBn.children = children @@ -851,12 +871,12 @@ func TestBranchNode_deleteAndReduceBn(t *testing.T) { var children [nrOfChildren]node firstChildPos := byte(2) secondChildPos := byte(6) - children[firstChildPos], _ = newLeafNode([]byte("dog"), []byte("dog"), bn.marsh, bn.hasher) - children[secondChildPos], _ = newLeafNode([]byte("doe"), []byte("doe"), bn.marsh, bn.hasher) + children[firstChildPos], _ = newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) + children[secondChildPos], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), bn.marsh, bn.hasher) bn.children = children key := append([]byte{firstChildPos}, []byte("dog")...) - ln, _ := newLeafNode(key, []byte("dog"), bn.marsh, bn.hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(key), "dog"), bn.marsh, bn.hasher) key = append([]byte{secondChildPos}, []byte("doe")...) dirty, newBn, _, err := bn.delete(key, nil) @@ -871,11 +891,11 @@ func TestBranchNode_reduceNode(t *testing.T) { bn, _ := newBranchNode(getTestMarshalizerAndHasher()) var children [nrOfChildren]node childPos := byte(2) - children[childPos], _ = newLeafNode([]byte("dog"), []byte("dog"), bn.marsh, bn.hasher) + children[childPos], _ = newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) bn.children = children key := append([]byte{childPos}, []byte("dog")...) - ln, _ := newLeafNode(key, []byte("dog"), bn.marsh, bn.hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(key), "dog"), bn.marsh, bn.hasher) n, newChildHash, err := bn.children[childPos].reduceNode(int(childPos)) assert.Equal(t, ln, n) @@ -1110,7 +1130,7 @@ func TestBranchNode_newBranchNodeNilMarshalizerShouldErr(t *testing.T) { func TestBranchNode_newBranchNodeNilHasherShouldErr(t *testing.T) { t.Parallel() - bn, err := newBranchNode(&testscommon.MarshalizerMock{}, nil) + bn, err := newBranchNode(&marshallerMock.MarshalizerMock{}, nil) assert.Nil(t, bn) assert.Equal(t, ErrNilHasher, err) } @@ -1133,7 +1153,7 @@ func TestBranchNode_newBranchNodeOkVals(t *testing.T) { func TestBranchNode_getMarshalizer(t *testing.T) { t.Parallel() - expectedMarsh := &testscommon.MarshalizerMock{} + expectedMarsh := &marshallerMock.MarshalizerMock{} bn := &branchNode{ baseNode: &baseNode{ marsh: expectedMarsh, @@ -1335,6 +1355,62 @@ func TestBranchNode_commitSnapshotDbIsClosing(t *testing.T) { assert.Equal(t, 0, len(missingNodesChan)) } +func TestBranchNode_getVersion(t *testing.T) { + t.Parallel() + + t.Run("nil ChildrenVersion", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + + version, err := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Nil(t, err) + }) + + t.Run("NotSpecified for all children", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.ChildrenVersion = make([]byte, nrOfChildren) + bn.ChildrenVersion[2] = byte(core.NotSpecified) + bn.ChildrenVersion[6] = byte(core.NotSpecified) + bn.ChildrenVersion[13] = byte(core.NotSpecified) + + version, err := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Nil(t, err) + }) + + t.Run("one child with autoBalanceEnabled", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.ChildrenVersion = make([]byte, nrOfChildren) + bn.ChildrenVersion[2] = byte(core.NotSpecified) + bn.ChildrenVersion[6] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[13] = byte(core.NotSpecified) + + version, err := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Nil(t, err) + }) + + t.Run("AutoBalanceEnabled for all children", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.ChildrenVersion = make([]byte, nrOfChildren) + bn.ChildrenVersion[2] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[6] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[13] = byte(core.AutoBalanceEnabled) + + version, err := bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + assert.Nil(t, err) + }) +} + func TestBranchNode_getValueReturnsEmptyByteSlice(t *testing.T) { t.Parallel() diff --git a/trie/doubleListSync_test.go b/trie/doubleListSync_test.go index c0a453242b9..65197f171fc 100644 --- a/trie/doubleListSync_test.go +++ b/trie/doubleListSync_test.go @@ -15,12 +15,14 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var marshalizer = &testscommon.MarshalizerMock{} +var marshalizer = &marshallerMock.MarshalizerMock{} var hasherMock = &hashingMocks.HasherMock{} func createMemUnit() storage.Storer { @@ -46,7 +48,7 @@ func createTrieStorageManager(store storage.Storer) (common.StorageManager, stor func createInMemoryTrie() (common.Trie, storage.Storer) { memUnit := createMemUnit() tsm, _ := createTrieStorageManager(memUnit) - tr, _ := NewTrie(tsm, marshalizer, hasherMock, 6) + tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 6) return tr, memUnit } @@ -59,7 +61,7 @@ func createInMemoryTrieFromDB(db storage.Persister) (common.Trie, storage.Storer unit, _ := storageunit.NewStorageUnit(cache, db) tsm, _ := createTrieStorageManager(unit) - tr, _ := NewTrie(tsm, marshalizer, hasherMock, 6) + tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 6) return tr, unit } diff --git a/trie/errors.go b/trie/errors.go index dc229f1c1b0..5e7c6d7973d 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -117,3 +117,12 @@ var ErrNilTrieIteratorErrChannel = errors.New("nil trie iterator error channel") // ErrInvalidIdentifier signals that an invalid identifier was provided var ErrInvalidIdentifier = errors.New("invalid identifier") + +// ErrNilKeyBuilder signals that a nil key builder has been provided +var ErrNilKeyBuilder = errors.New("nil key builder") + +// ErrNilTrieLeafParser signals that a nil trie leaf parser has been provided +var ErrNilTrieLeafParser = errors.New("nil trie leaf parser") + +// ErrInvalidNodeVersion signals that an invalid node version has been provided +var ErrInvalidNodeVersion = errors.New("invalid node version provided") diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 04871193be8..42c081d6eb6 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "math" "strings" "sync" @@ -14,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&extensionNode{}) @@ -25,11 +27,20 @@ func newExtensionNode(key []byte, child node, marshalizer marshal.Marshalizer, h if check.IfNil(hasher) { return nil, ErrNilHasher } + if check.IfNil(child) { + return nil, ErrNilNode + } + + childVersion, err := child.getVersion() + if err != nil { + return nil, err + } return &extensionNode{ CollapsedEn: CollapsedEn{ Key: key, EncodedChild: nil, + ChildVersion: uint32(childVersion), }, child: child, baseNode: &baseNode{ @@ -372,7 +383,7 @@ func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (n return en.child, key, nil } -func (en *extensionNode) insert(n *leafNode, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { @@ -383,21 +394,21 @@ func (en *extensionNode) insert(n *leafNode, db common.TrieStorageInteractor) (n return nil, emptyHashes, err } - keyMatchLen := prefixLen(n.Key, en.Key) + keyMatchLen := prefixLen(newData.Key, en.Key) // If the whole key matches, keep this extension node as is // and only update the value. if keyMatchLen == len(en.Key) { - return en.insertInSameEn(n, keyMatchLen, db) + return en.insertInSameEn(newData, keyMatchLen, db) } // Otherwise branch out at the index where they differ. - return en.insertInNewBn(n, keyMatchLen) + return en.insertInNewBn(newData, keyMatchLen) } -func (en *extensionNode) insertInSameEn(n *leafNode, keyMatchLen int, db common.TrieStorageInteractor) (node, [][]byte, error) { - n.Key = n.Key[keyMatchLen:] - newNode, oldHashes, err := en.child.insert(n, db) +func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, db common.TrieStorageInteractor) (node, [][]byte, error) { + newData.Key = newData.Key[keyMatchLen:] + newNode, oldHashes, err := en.child.insert(newData, db) if check.IfNil(newNode) || err != nil { return nil, [][]byte{}, err } @@ -414,7 +425,7 @@ func (en *extensionNode) insertInSameEn(n *leafNode, keyMatchLen int, db common. return newEn, oldHashes, nil } -func (en *extensionNode) insertInNewBn(n *leafNode, keyMatchLen int) (node, [][]byte, error) { +func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, [][]byte, error) { oldHash := make([][]byte, 0) if !en.dirty { oldHash = append(oldHash, en.hash) @@ -426,23 +437,20 @@ func (en *extensionNode) insertInNewBn(n *leafNode, keyMatchLen int) (node, [][] } oldChildPos := en.Key[keyMatchLen] - newChildPos := n.Key[keyMatchLen] + newChildPos := newData.Key[keyMatchLen] if childPosOutOfRange(oldChildPos) || childPosOutOfRange(newChildPos) { return nil, [][]byte{}, ErrChildPosOutOfRange } - followingExtensionNode, err := newExtensionNode(en.Key[keyMatchLen+1:], en.child, en.marsh, en.hasher) + err = en.insertOldChildInBn(bn, oldChildPos, keyMatchLen) if err != nil { return nil, [][]byte{}, err } - if len(followingExtensionNode.Key) < 1 { - bn.children[oldChildPos] = en.child - } else { - bn.children[oldChildPos] = followingExtensionNode + err = en.insertNewChildInBn(bn, newData, newChildPos, keyMatchLen) + if err != nil { + return nil, [][]byte{}, err } - n.Key = n.Key[keyMatchLen+1:] - bn.children[newChildPos] = n if keyMatchLen == 0 { return bn, oldHash, nil @@ -456,6 +464,41 @@ func (en *extensionNode) insertInNewBn(n *leafNode, keyMatchLen int) (node, [][] return newEn, oldHash, nil } +func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, keyMatchLen int) error { + keyReminder := en.Key[keyMatchLen+1:] + childVersion, err := en.child.getVersion() + if err != nil { + return err + } + bn.setVersionForChild(childVersion, oldChildPos) + + if len(keyReminder) < 1 { + bn.children[oldChildPos] = en.child + return nil + } + + followingExtensionNode, err := newExtensionNode(en.Key[keyMatchLen+1:], en.child, en.marsh, en.hasher) + if err != nil { + return err + } + + bn.children[oldChildPos] = followingExtensionNode + return nil +} + +func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieData, newChildPos byte, keyMatchLen int) error { + newData.Key = newData.Key[keyMatchLen+1:] + + newLeaf, err := newLeafNode(newData, en.marsh, en.hasher) + if err != nil { + return err + } + + bn.children[newChildPos] = newLeaf + bn.setVersionForChild(newData.Version, newChildPos) + return nil +} + func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() @@ -483,29 +526,38 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo oldHashes = append(oldHashes, en.hash) } - var n node switch newNode := newNode.(type) { case *leafNode: - n, err = newLeafNode(concat(en.Key, newNode.Key...), newNode.Value, en.marsh, en.hasher) + newLeafData := core.TrieData{ + Key: concat(en.Key, newNode.Key...), + Value: newNode.Value, + Version: core.TrieNodeVersion(newNode.Version), + } + n, err := newLeafNode(newLeafData, en.marsh, en.hasher) if err != nil { return false, nil, emptyHashes, err } return true, n, oldHashes, nil case *extensionNode: - n, err = newExtensionNode(concat(en.Key, newNode.Key...), newNode.child, en.marsh, en.hasher) + n, err := newExtensionNode(concat(en.Key, newNode.Key...), newNode.child, en.marsh, en.hasher) if err != nil { return false, nil, emptyHashes, err } return true, n, oldHashes, nil - default: - n, err = newExtensionNode(en.Key, newNode, en.marsh, en.hasher) + case *branchNode: + n, err := newExtensionNode(en.Key, newNode, en.marsh, en.hasher) if err != nil { return false, nil, emptyHashes, err } return true, n, oldHashes, nil + case nil: + log.Warn("nil child after deleting from extension node") + return true, nil, oldHashes, nil + default: + return false, nil, oldHashes, ErrInvalidNode } } @@ -639,6 +691,7 @@ func (en *extensionNode) loadChildren(getNode func([]byte) (node, error)) ([][]b func (en *extensionNode) getAllLeavesOnChannel( leavesChannel chan core.KeyValueHolder, keyBuilder common.KeyBuilder, + trieLeafParser common.TrieLeafParser, db common.TrieStorageInteractor, marshalizer marshal.Marshalizer, chanClose chan struct{}, @@ -663,7 +716,7 @@ func (en *extensionNode) getAllLeavesOnChannel( } keyBuilder.BuildKey(en.Key) - err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.Clone(), db, marshalizer, chanClose, ctx) + err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.Clone(), trieLeafParser, db, marshalizer, chanClose, ctx) if err != nil { return err } @@ -747,6 +800,42 @@ func (en *extensionNode) collectStats(ts common.TrieStatisticsHandler, depthLeve return nil } +func (en *extensionNode) getVersion() (core.TrieNodeVersion, error) { + if en.ChildVersion > math.MaxUint8 { + log.Warn("invalid trie node version for extension node", "child version", en.ChildVersion, "max version", math.MaxUint8) + return core.NotSpecified, ErrInvalidNodeVersion + } + + return core.TrieNodeVersion(en.ChildVersion), nil +} + +func (en *extensionNode) collectLeavesForMigration( + migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + db common.TrieStorageInteractor, + keyBuilder common.KeyBuilder, +) (bool, error) { + hasEnoughGasToContinueMigration := migrationArgs.TrieMigrator.ConsumeStorageLoadGas() + if !hasEnoughGasToContinueMigration { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(en, migrationArgs) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + err = resolveIfCollapsed(en, 0, db) + if err != nil { + return false, err + } + + keyBuilder.BuildKey(en.Key) + return en.child.collectLeavesForMigration(migrationArgs, db, keyBuilder.Clone()) +} + // IsInterfaceNil returns true if there is no value under the interface func (en *extensionNode) IsInterfaceNil() bool { return en == nil diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index f24f8edbf14..ac243f3aaff 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "math" "testing" "github.com/multiversx/mx-chain-core-go/core" @@ -11,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) @@ -363,7 +365,7 @@ func TestExtensionNode_isCollapsed(t *testing.T) { assert.True(t, collapsedEn.isCollapsed()) assert.False(t, en.isCollapsed()) - collapsedEn.child, _ = newLeafNode([]byte("og"), []byte("dog"), en.marsh, en.hasher) + collapsedEn.child, _ = newLeafNode(getTrieDataWithDefaultVersion("og", "dog"), en.marsh, en.hasher) assert.False(t, collapsedEn.isCollapsed()) } @@ -490,9 +492,8 @@ func TestExtensionNode_insert(t *testing.T) { en, _ := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} - n, _ := newLeafNode(key, []byte("dogs"), en.marsh, en.hasher) - newNode, _, err := en.insert(n, nil) + newNode, _, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -506,12 +507,11 @@ func TestExtensionNode_insertCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} - n, _ := newLeafNode(key, []byte("dogs"), en.marsh, en.hasher) _ = en.setHash() _ = en.commitDirty(0, 5, db, db) - newNode, _, err := collapsedEn.insert(n, db) + newNode, _, err := collapsedEn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -526,7 +526,6 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { en, _ := getEnAndCollapsedEn() enKey := []byte{100} key := append(enKey, []byte{11, 12}...) - n, _ := newLeafNode(key, []byte("dogs"), en.marsh, en.hasher) _ = en.commitDirty(0, 5, db, db) enHash := en.getHash() @@ -534,7 +533,7 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { bnHash := bn.getHash() expectedHashes := [][]byte{bnHash, enHash} - newNode, oldHashes, err := en.insert(n, db) + newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -548,12 +547,11 @@ func TestExtensionNode_insertInStoredEnDifferentKey(t *testing.T) { enKey := []byte{1} en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} - n, _ := newLeafNode(nodeKey, []byte("dogs"), bn.marsh, bn.hasher) _ = en.commitDirty(0, 5, db, db) expectedHashes := [][]byte{en.getHash()} - newNode, oldHashes, err := en.insert(n, db) + newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -564,9 +562,8 @@ func TestExtensionNode_insertInDirtyEnSameKey(t *testing.T) { en, _ := getEnAndCollapsedEn() nodeKey := []byte{100, 11, 12} - n, _ := newLeafNode(nodeKey, []byte("dogs"), en.marsh, en.hasher) - newNode, oldHashes, err := en.insert(n, nil) + newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -579,9 +576,8 @@ func TestExtensionNode_insertInDirtyEnDifferentKey(t *testing.T) { enKey := []byte{1} en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} - n, _ := newLeafNode(nodeKey, []byte("dogs"), bn.marsh, bn.hasher) - newNode, oldHashes, err := en.insert(n, nil) + newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -592,7 +588,7 @@ func TestExtensionNode_insertInNilNode(t *testing.T) { var en *extensionNode - newNode, _, err := en.insert(&leafNode{}, nil) + newNode, _, err := en.insert(getTrieDataWithDefaultVersion("key", "val"), nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, newNode) @@ -716,11 +712,13 @@ func TestExtensionNode_reduceNode(t *testing.T) { t.Parallel() marsh, hasher := getTestMarshalizerAndHasher() - en, _ := newExtensionNode([]byte{100, 111, 103}, nil, marsh, hasher) + bn, _ := getBnAndCollapsedBn(marsh, hasher) + en, _ := newExtensionNode([]byte{100, 111, 103}, bn, marsh, hasher) expected := &extensionNode{CollapsedEn: CollapsedEn{Key: []byte{2, 100, 111, 103}}, baseNode: &baseNode{dirty: true}} expected.marsh = en.marsh expected.hasher = en.hasher + expected.child = en.child n, newChildPos, err := en.reduceNode(2) assert.Equal(t, expected, n) @@ -850,7 +848,7 @@ func TestExtensionNode_newExtensionNodeNilMarshalizerShouldErr(t *testing.T) { func TestExtensionNode_newExtensionNodeNilHasherShouldErr(t *testing.T) { t.Parallel() - en, err := newExtensionNode([]byte("key"), &branchNode{}, &testscommon.MarshalizerMock{}, nil) + en, err := newExtensionNode([]byte("key"), &branchNode{}, &marshallerMock.MarshalizerMock{}, nil) assert.Nil(t, en) assert.Equal(t, ErrNilHasher, err) } @@ -860,7 +858,7 @@ func TestExtensionNode_newExtensionNodeOkVals(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() key := []byte("key") - child := &branchNode{} + child, _ := getBnAndCollapsedBn(marsh, hasher) en, err := newExtensionNode(key, child, marsh, hasher) assert.Nil(t, err) @@ -1047,3 +1045,39 @@ func TestExtensionNode_commitSnapshotDbIsClosing(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 0, len(missingNodesChan)) } + +func TestExtensionNode_getVersion(t *testing.T) { + t.Parallel() + + t.Run("invalid node version", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + en.ChildVersion = math.MaxUint8 + 1 + + version, err := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Equal(t, ErrInvalidNodeVersion, err) + }) + + t.Run("NotSpecified version", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + + version, err := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Nil(t, err) + }) + + t.Run("AutoBalanceEnabled version", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + en.ChildVersion = uint32(core.AutoBalanceEnabled) + + version, err := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + assert.Nil(t, err) + }) +} diff --git a/trie/factory/interface.go b/trie/factory/interface.go index 1c1b9aabbfd..f1d483d5756 100644 --- a/trie/factory/interface.go +++ b/trie/factory/interface.go @@ -12,4 +12,5 @@ type coreComponentsHandler interface { Hasher() hashing.Hasher PathHandler() storage.PathManagerHandler ProcessStatusHandler() common.ProcessStatusHandler + EnableEpochsHandler() common.EnableEpochsHandler } diff --git a/trie/factory/trieCreator.go b/trie/factory/trieCreator.go index 2958e9edccf..28353ef6fc0 100644 --- a/trie/factory/trieCreator.go +++ b/trie/factory/trieCreator.go @@ -24,6 +24,7 @@ type TrieCreateArgs struct { MaxTrieLevelInMem uint IdleProvider trie.IdleNodeProvider Identifier string + EnableEpochsHandler common.EnableEpochsHandler } type trieCreator struct { @@ -82,7 +83,7 @@ func (tc *trieCreator) Create(args TrieCreateArgs) (common.StorageManager, commo return nil, nil, err } - newTrie, err := trie.NewTrie(trieStorage, tc.marshalizer, tc.hasher, args.MaxTrieLevelInMem) + newTrie, err := trie.NewTrie(trieStorage, tc.marshalizer, tc.hasher, args.EnableEpochsHandler, args.MaxTrieLevelInMem) if err != nil { return nil, nil, err } @@ -143,6 +144,7 @@ func CreateTriesComponentsForShardId( SnapshotsEnabled: snapshotsEnabled, IdleProvider: coreComponentsHolder.ProcessStatusHandler(), Identifier: dataRetriever.UserAccountsUnit.String(), + EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), } userStorageManager, userAccountTrie, err := trFactory.Create(args) if err != nil { @@ -174,6 +176,7 @@ func CreateTriesComponentsForShardId( SnapshotsEnabled: snapshotsEnabled, IdleProvider: coreComponentsHolder.ProcessStatusHandler(), Identifier: dataRetriever.PeerAccountsUnit.String(), + EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), } peerStorageManager, peerAccountsTrie, err := trFactory.Create(args) if err != nil { diff --git a/trie/factory/trieCreator_test.go b/trie/factory/trieCreator_test.go index 55bba27cea4..6d72b2f2819 100644 --- a/trie/factory/trieCreator_test.go +++ b/trie/factory/trieCreator_test.go @@ -11,7 +11,9 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/factory" @@ -21,7 +23,7 @@ import ( func getArgs() factory.TrieFactoryArgs { return factory.TrieFactoryArgs{ - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, PathManager: &testscommon.PathManagerStub{}, TrieStorageManagerConfig: config.TrieStorageManagerConfig{SnapshotsGoroutineNum: 1}, @@ -38,6 +40,7 @@ func getCreateArgs() factory.TrieCreateArgs { MaxTrieLevelInMem: 5, IdleProvider: &testscommon.ProcessStatusHandlerStub{}, Identifier: dataRetriever.UserAccountsUnit.String(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } @@ -194,10 +197,11 @@ func TestTrieCreator_CreateTriesComponentsForShardId(t *testing.T) { false, testscommon.GetGeneralConfig(), &mock.CoreComponentsStub{ - InternalMarshalizerField: &testscommon.MarshalizerMock{}, + InternalMarshalizerField: &marshallerMock.MarshalizerMock{}, HasherField: &hashingMocks.HasherMock{}, PathHandlerField: &testscommon.PathManagerStub{}, ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { @@ -219,10 +223,11 @@ func testWithMissingStorer(missingUnit dataRetriever.UnitType) func(t *testing.T false, testscommon.GetGeneralConfig(), &mock.CoreComponentsStub{ - InternalMarshalizerField: &testscommon.MarshalizerMock{}, + InternalMarshalizerField: &marshallerMock.MarshalizerMock{}, HasherField: &hashingMocks.HasherMock{}, PathHandlerField: &testscommon.PathManagerStub{}, ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { diff --git a/trie/interface.go b/trie/interface.go index 50c17b33a1f..a3120986239 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -29,7 +29,7 @@ type node interface { hashChildren() error tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) - insert(n *leafNode, db common.TrieStorageInteractor) (node, [][]byte, error) + insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) reduceNode(pos int) (node, bool, error) isEmptyOrNil() error @@ -39,10 +39,12 @@ type node interface { isValid() bool setDirty(bool) loadChildren(func([]byte) (node, error)) ([][]byte, []node, error) - getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.TrieStorageInteractor, marshal.Marshalizer, chan struct{}, context.Context) error + getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.TrieLeafParser, common.TrieStorageInteractor, marshal.Marshalizer, chan struct{}, context.Context) error getAllHashes(db common.TrieStorageInteractor) ([][]byte, error) getNextHashAndKey([]byte) (bool, []byte, []byte) getValue() []byte + getVersion() (core.TrieNodeVersion, error) + collectLeavesForMigration(migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, db common.TrieStorageInteractor, keyBuilder common.KeyBuilder) (bool, error) commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error commitCheckpoint(originDb common.TrieStorageInteractor, targetDb common.BaseStorer, checkpointHashes CheckpointHashesHolder, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, depthLevel int) error diff --git a/trie/keyBuilder/disabledKeyBuilder.go b/trie/keyBuilder/disabledKeyBuilder.go index b9f1d11c2bc..a930f4baff1 100644 --- a/trie/keyBuilder/disabledKeyBuilder.go +++ b/trie/keyBuilder/disabledKeyBuilder.go @@ -26,3 +26,8 @@ func (dkb *disabledKeyBuilder) GetKey() ([]byte, error) { func (dkb *disabledKeyBuilder) Clone() common.KeyBuilder { return &disabledKeyBuilder{} } + +// IsInterfaceNil returns true if there is no value under the interface +func (dkb *disabledKeyBuilder) IsInterfaceNil() bool { + return dkb == nil +} diff --git a/trie/keyBuilder/keyBuilder.go b/trie/keyBuilder/keyBuilder.go index 0c832c011b6..787b1d66e0e 100644 --- a/trie/keyBuilder/keyBuilder.go +++ b/trie/keyBuilder/keyBuilder.go @@ -59,3 +59,8 @@ func hexToTrieKeyBytes(hex []byte) ([]byte, error) { return key, nil } + +// IsInterfaceNil returns true if there is no value under the interface +func (kb *keyBuilder) IsInterfaceNil() bool { + return kb == nil +} diff --git a/trie/leafNode.go b/trie/leafNode.go index e20a38d4afd..9dcf1a2f3b9 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "math" "sync" "github.com/multiversx/mx-chain-core-go/core" @@ -14,11 +15,16 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&leafNode{}) -func newLeafNode(key, value []byte, marshalizer marshal.Marshalizer, hasher hashing.Hasher) (*leafNode, error) { +func newLeafNode( + newData core.TrieData, + marshalizer marshal.Marshalizer, + hasher hashing.Hasher, +) (*leafNode, error) { if check.IfNil(marshalizer) { return nil, ErrNilMarshalizer } @@ -28,8 +34,9 @@ func newLeafNode(key, value []byte, marshalizer marshal.Marshalizer, hasher hash return &leafNode{ CollapsedLn: CollapsedLn{ - Key: key, - Value: value, + Key: newData.Key, + Value: newData.Value, + Version: uint32(newData.Version), }, baseNode: &baseNode{ dirty: true, @@ -168,7 +175,12 @@ func (ln *leafNode) commitCheckpoint( return err } - stats.AddLeafNode(depthLevel, uint64(nodeSize)) + version, err := ln.getVersion() + if err != nil { + return err + } + + stats.AddLeafNode(depthLevel, uint64(nodeSize), version) return nil } @@ -201,7 +213,12 @@ func (ln *leafNode) commitSnapshot( return err } - stats.AddLeafNode(depthLevel, uint64(nodeSize)) + version, err := ln.getVersion() + if err != nil { + return err + } + + stats.AddLeafNode(depthLevel, uint64(nodeSize), version) return nil } @@ -269,7 +286,7 @@ func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, [ } return nil, nil, ErrNodeNotFound } -func (ln *leafNode) insert(n *leafNode, _ common.TrieStorageInteractor) (node, [][]byte, error) { +func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor) (node, [][]byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, [][]byte{}, fmt.Errorf("insert error %w", err) @@ -280,15 +297,14 @@ func (ln *leafNode) insert(n *leafNode, _ common.TrieStorageInteractor) (node, [ oldHash = append(oldHash, ln.hash) } - insertedKey := n.Key nodeKey := ln.Key - if bytes.Equal(insertedKey, nodeKey) { - return ln.insertInSameLn(n, oldHash) + if bytes.Equal(newData.Key, nodeKey) { + return ln.insertInSameLn(newData, oldHash) } - keyMatchLen := prefixLen(insertedKey, nodeKey) - bn, err := ln.insertInNewBn(n, keyMatchLen) + keyMatchLen := prefixLen(newData.Key, nodeKey) + bn, err := ln.insertInNewBn(newData, keyMatchLen) if err != nil { return nil, [][]byte{}, err } @@ -305,40 +321,54 @@ func (ln *leafNode) insert(n *leafNode, _ common.TrieStorageInteractor) (node, [ return newEn, oldHash, nil } -func (ln *leafNode) insertInSameLn(n *leafNode, oldHashes [][]byte) (node, [][]byte, error) { - if bytes.Equal(ln.Value, n.Value) { +func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte) (node, [][]byte, error) { + if bytes.Equal(ln.Value, newData.Value) { return nil, [][]byte{}, nil } - ln.Value = n.Value + ln.Value = newData.Value + ln.Version = uint32(newData.Version) ln.dirty = true ln.hash = nil return ln, oldHashes, nil } -func (ln *leafNode) insertInNewBn(n *leafNode, keyMatchLen int) (node, error) { +func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, error) { bn, err := newBranchNode(ln.marsh, ln.hasher) if err != nil { return nil, err } oldChildPos := ln.Key[keyMatchLen] - newChildPos := n.Key[keyMatchLen] + newChildPos := newData.Key[keyMatchLen] if childPosOutOfRange(oldChildPos) || childPosOutOfRange(newChildPos) { return nil, ErrChildPosOutOfRange } - newLnOldChildPos, err := newLeafNode(ln.Key[keyMatchLen+1:], ln.Value, ln.marsh, ln.hasher) + oldLnVersion, err := ln.getVersion() + if err != nil { + return nil, err + } + + oldLnData := core.TrieData{ + Key: ln.Key[keyMatchLen+1:], + Value: ln.Value, + Version: oldLnVersion, + } + newLnOldChildPos, err := newLeafNode(oldLnData, ln.marsh, ln.hasher) if err != nil { return nil, err } bn.children[oldChildPos] = newLnOldChildPos + bn.setVersionForChild(oldLnVersion, oldChildPos) - newLnNewChildPos, err := newLeafNode(n.Key[keyMatchLen+1:], n.Value, ln.marsh, ln.hasher) + newData.Key = newData.Key[keyMatchLen+1:] + newLnNewChildPos, err := newLeafNode(newData, ln.marsh, ln.hasher) if err != nil { return nil, err } bn.children[newChildPos] = newLnNewChildPos + bn.setVersionForChild(newData.Version, newChildPos) return bn, nil } @@ -358,7 +388,18 @@ func (ln *leafNode) delete(key []byte, _ common.TrieStorageInteractor) (bool, no func (ln *leafNode) reduceNode(pos int) (node, bool, error) { k := append([]byte{byte(pos)}, ln.Key...) - newLn, err := newLeafNode(k, ln.Value, ln.marsh, ln.hasher) + oldLnVersion, err := ln.getVersion() + if err != nil { + return nil, false, err + } + + oldLnData := core.TrieData{ + Key: k, + Value: ln.Value, + Version: oldLnVersion, + } + + newLn, err := newLeafNode(oldLnData, ln.marsh, ln.hasher) if err != nil { return nil, false, err } @@ -427,6 +468,7 @@ func (ln *leafNode) loadChildren(_ func([]byte) (node, error)) ([][]byte, []node func (ln *leafNode) getAllLeavesOnChannel( leavesChannel chan core.KeyValueHolder, keyBuilder common.KeyBuilder, + trieLeafParser common.TrieLeafParser, _ common.TrieStorageInteractor, _ marshal.Marshalizer, chanClose chan struct{}, @@ -443,7 +485,16 @@ func (ln *leafNode) getAllLeavesOnChannel( return err } - trieLeaf := keyValStorage.NewKeyValStorage(nodeKey, ln.Value) + version, err := ln.getVersion() + if err != nil { + return err + } + + trieLeaf, err := trieLeafParser.ParseLeaf(nodeKey, ln.Value, version) + if err != nil { + return err + } + for { select { case <-chanClose: @@ -505,10 +556,62 @@ func (ln *leafNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int return err } - ts.AddLeafNode(depthLevel, uint64(len(val))) + version, err := ln.getVersion() + if err != nil { + return err + } + + ts.AddLeafNode(depthLevel, uint64(len(val)), version) return nil } +func (ln *leafNode) getVersion() (core.TrieNodeVersion, error) { + if ln.Version > math.MaxUint8 { + log.Warn("invalid trie node version", "version", ln.Version, "max version", math.MaxUint8) + return core.NotSpecified, ErrInvalidNodeVersion + } + + return core.TrieNodeVersion(ln.Version), nil +} + +func (ln *leafNode) collectLeavesForMigration( + migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + _ common.TrieStorageInteractor, + keyBuilder common.KeyBuilder, +) (bool, error) { + shouldContinue := migrationArgs.TrieMigrator.ConsumeStorageLoadGas() + if !shouldContinue { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(ln, migrationArgs) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + keyBuilder.BuildKey(ln.Key) + key, err := keyBuilder.GetKey() + if err != nil { + return false, err + } + + version, err := ln.getVersion() + if err != nil { + return false, err + } + + leafData := core.TrieData{ + Key: key, + Value: ln.Value, + Version: version, + } + + return migrationArgs.TrieMigrator.AddLeafToMigrationQueue(leafData, migrationArgs.NewVersion) +} + // IsInterfaceNil returns true if there is no value under the interface func (ln *leafNode) IsInterfaceNil() bool { return ln == nil diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index bf9cab8209b..c40d1cf1a7d 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -3,6 +3,7 @@ package trie import ( "context" "errors" + "math" "testing" "github.com/multiversx/mx-chain-core-go/core" @@ -11,12 +12,13 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) func getLn(marsh marshal.Marshalizer, hasher hashing.Hasher) *leafNode { - newLn, _ := newLeafNode([]byte("dog"), []byte("dog"), marsh, hasher) + newLn, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), marsh, hasher) return newLn } @@ -35,7 +37,7 @@ func TestLeafNode_newLeafNode(t *testing.T) { hasher: hasher, }, } - ln, _ := newLeafNode([]byte("dog"), []byte("dog"), marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), marsh, hasher) assert.Equal(t, expectedLn, ln) } @@ -317,16 +319,15 @@ func TestLeafNode_insertAtSameKey(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) - key := []byte("dog") - expectedVal := []byte("dogs") - n, _ := newLeafNode(key, expectedVal, ln.marsh, ln.hasher) + key := "dog" + expectedVal := "dogs" - newNode, _, err := ln.insert(n, nil) + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(key, expectedVal), nil) assert.NotNil(t, newNode) assert.Nil(t, err) - val, _, _ := newNode.tryGet(key, 0, nil) - assert.Equal(t, expectedVal, val) + val, _, _ := newNode.tryGet([]byte(key), 0, nil) + assert.Equal(t, []byte(expectedVal), val) } func TestLeafNode_insertAtDifferentKey(t *testing.T) { @@ -335,13 +336,12 @@ func TestLeafNode_insertAtDifferentKey(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() lnKey := []byte{2, 100, 111, 103} - ln, _ := newLeafNode(lnKey, []byte("dog"), marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(lnKey), "dog"), marsh, hasher) nodeKey := []byte{3, 4, 5} nodeVal := []byte{3, 4, 5} - n, _ := newLeafNode(nodeKey, nodeVal, marsh, hasher) - newNode, _, err := ln.insert(n, nil) + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(string(nodeKey), string(nodeVal)), nil) assert.NotNil(t, newNode) assert.Nil(t, err) @@ -355,11 +355,10 @@ func TestLeafNode_insertInStoredLnAtSameKey(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - n, _ := newLeafNode([]byte("dog"), []byte("dogs"), ln.marsh, ln.hasher) _ = ln.commitDirty(0, 5, db, db) lnHash := ln.getHash() - newNode, oldHashes, err := ln.insert(n, db) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -370,12 +369,11 @@ func TestLeafNode_insertInStoredLnAtDifferentKey(t *testing.T) { db := testscommon.NewMemDbMock() marsh, hasher := getTestMarshalizerAndHasher() - ln, _ := newLeafNode([]byte{1, 2, 3}, []byte("dog"), marsh, hasher) - n, _ := newLeafNode([]byte{4, 5, 6}, []byte("dogs"), marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) _ = ln.commitDirty(0, 5, db, db) lnHash := ln.getHash() - newNode, oldHashes, err := ln.insert(n, db) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -385,9 +383,8 @@ func TestLeafNode_insertInDirtyLnAtSameKey(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) - n, _ := newLeafNode([]byte("dog"), []byte("dogs"), ln.marsh, ln.hasher) - newNode, oldHashes, err := ln.insert(n, nil) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -397,10 +394,9 @@ func TestLeafNode_insertInDirtyLnAtDifferentKey(t *testing.T) { t.Parallel() marsh, hasher := getTestMarshalizerAndHasher() - ln, _ := newLeafNode([]byte{1, 2, 3}, []byte("dog"), marsh, hasher) - n, _ := newLeafNode([]byte{4, 5, 6}, []byte("dogs"), marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) - newNode, oldHashes, err := ln.insert(n, nil) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -411,7 +407,7 @@ func TestLeafNode_insertInNilNode(t *testing.T) { var ln *leafNode - newNode, _, err := ln.insert(&leafNode{}, nil) + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilLeafNode)) assert.Nil(t, newNode) @@ -483,8 +479,8 @@ func TestLeafNode_reduceNode(t *testing.T) { t.Parallel() marsh, hasher := getTestMarshalizerAndHasher() - ln, _ := newLeafNode([]byte{100, 111, 103}, nil, marsh, hasher) - expected, _ := newLeafNode([]byte{2, 100, 111, 103}, nil, marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{100, 111, 103}), ""), marsh, hasher) + expected, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{2, 100, 111, 103}), ""), marsh, hasher) expected.dirty = true n, newChildHash, err := ln.reduceNode(2) @@ -610,7 +606,7 @@ func TestLeafNode_deleteDifferentKeyShouldNotModifyTrie(t *testing.T) { func TestLeafNode_newLeafNodeNilMarshalizerShouldErr(t *testing.T) { t.Parallel() - ln, err := newLeafNode([]byte("key"), []byte("val"), nil, &hashingMocks.HasherMock{}) + ln, err := newLeafNode(getTrieDataWithDefaultVersion("key", "val"), nil, &hashingMocks.HasherMock{}) assert.Nil(t, ln) assert.Equal(t, ErrNilMarshalizer, err) } @@ -618,7 +614,7 @@ func TestLeafNode_newLeafNodeNilMarshalizerShouldErr(t *testing.T) { func TestLeafNode_newLeafNodeNilHasherShouldErr(t *testing.T) { t.Parallel() - ln, err := newLeafNode([]byte("key"), []byte("val"), &testscommon.MarshalizerMock{}, nil) + ln, err := newLeafNode(getTrieDataWithDefaultVersion("key", "val"), &marshallerMock.MarshalizerMock{}, nil) assert.Nil(t, ln) assert.Equal(t, ErrNilHasher, err) } @@ -629,7 +625,7 @@ func TestLeafNode_newLeafNodeOkVals(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() key := []byte("key") val := []byte("val") - ln, err := newLeafNode(key, val, marsh, hasher) + ln, err := newLeafNode(getTrieDataWithDefaultVersion("key", "val"), marsh, hasher) assert.Nil(t, err) assert.Equal(t, key, ln.Key) @@ -744,3 +740,39 @@ func TestLeafNode_getValue(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) assert.Equal(t, ln.Value, ln.getValue()) } + +func TestLeafNode_getVersion(t *testing.T) { + t.Parallel() + + t.Run("invalid node version", func(t *testing.T) { + t.Parallel() + + ln := getLn(getTestMarshalizerAndHasher()) + ln.Version = math.MaxUint8 + 1 + + version, err := ln.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Equal(t, ErrInvalidNodeVersion, err) + }) + + t.Run("NotSpecified version", func(t *testing.T) { + t.Parallel() + + ln := getLn(getTestMarshalizerAndHasher()) + + version, err := ln.getVersion() + assert.Equal(t, core.NotSpecified, version) + assert.Nil(t, err) + }) + + t.Run("AutoBalanceEnabled version", func(t *testing.T) { + t.Parallel() + + ln := getLn(getTestMarshalizerAndHasher()) + ln.Version = uint32(core.AutoBalanceEnabled) + + version, err := ln.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + assert.Nil(t, err) + }) +} diff --git a/trie/mock/keyBuilderStub.go b/trie/mock/keyBuilderStub.go index b0c0fdbc24a..8ba29de2213 100644 --- a/trie/mock/keyBuilderStub.go +++ b/trie/mock/keyBuilderStub.go @@ -33,3 +33,8 @@ func (stub *KeyBuilderStub) Clone() common.KeyBuilder { return &KeyBuilderStub{} } + +// IsInterfaceNil - +func (stub *KeyBuilderStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/trie/node.go b/trie/node.go index 617aab8b528..0a3a4545e3f 100644 --- a/trie/node.go +++ b/trie/node.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/trie/keyBuilder" logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) const ( @@ -279,3 +280,23 @@ func treatCommitSnapshotError(err error, hash []byte, missingNodesChan chan []by log.Error("error during trie snapshot", "err", err.Error(), "hash", hash) missingNodesChan <- hash } + +func shouldMigrateCurrentNode( + currentNode node, + migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, +) (bool, error) { + version, err := currentNode.getVersion() + if err != nil { + return false, err + } + + if version == migrationArgs.NewVersion { + return false, nil + } + + if version != migrationArgs.OldVersion && version != core.NotSpecified { + return false, nil + } + + return true, nil +} diff --git a/trie/node.pb.go b/trie/node.pb.go index 0fba59b3106..ee618b1ff67 100644 --- a/trie/node.pb.go +++ b/trie/node.pb.go @@ -28,6 +28,7 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type CollapsedBn struct { EncodedChildren [][]byte `protobuf:"bytes,1,rep,name=EncodedChildren,proto3" json:"EncodedChildren,omitempty"` + ChildrenVersion []byte `protobuf:"bytes,2,opt,name=ChildrenVersion,proto3" json:"ChildrenVersion,omitempty"` } func (m *CollapsedBn) Reset() { *m = CollapsedBn{} } @@ -65,9 +66,17 @@ func (m *CollapsedBn) GetEncodedChildren() [][]byte { return nil } +func (m *CollapsedBn) GetChildrenVersion() []byte { + if m != nil { + return m.ChildrenVersion + } + return nil +} + type CollapsedEn struct { Key []byte `protobuf:"bytes,1,opt,name=Key,proto3" json:"Key,omitempty"` EncodedChild []byte `protobuf:"bytes,2,opt,name=EncodedChild,proto3" json:"EncodedChild,omitempty"` + ChildVersion uint32 `protobuf:"varint,3,opt,name=ChildVersion,proto3" json:"ChildVersion,omitempty"` } func (m *CollapsedEn) Reset() { *m = CollapsedEn{} } @@ -112,9 +121,17 @@ func (m *CollapsedEn) GetEncodedChild() []byte { return nil } +func (m *CollapsedEn) GetChildVersion() uint32 { + if m != nil { + return m.ChildVersion + } + return 0 +} + type CollapsedLn struct { - Key []byte `protobuf:"bytes,1,opt,name=Key,proto3" json:"Key,omitempty"` - Value []byte `protobuf:"bytes,2,opt,name=Value,proto3" json:"Value,omitempty"` + Key []byte `protobuf:"bytes,1,opt,name=Key,proto3" json:"Key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=Value,proto3" json:"Value,omitempty"` + Version uint32 `protobuf:"varint,3,opt,name=Version,proto3" json:"Version,omitempty"` } func (m *CollapsedLn) Reset() { *m = CollapsedLn{} } @@ -159,6 +176,13 @@ func (m *CollapsedLn) GetValue() []byte { return nil } +func (m *CollapsedLn) GetVersion() uint32 { + if m != nil { + return m.Version + } + return 0 +} + func init() { proto.RegisterType((*CollapsedBn)(nil), "proto.CollapsedBn") proto.RegisterType((*CollapsedEn)(nil), "proto.CollapsedEn") @@ -168,23 +192,25 @@ func init() { func init() { proto.RegisterFile("node.proto", fileDescriptor_0c843d59d2d938e7) } var fileDescriptor_0c843d59d2d938e7 = []byte{ - // 245 bytes of a gzipped FileDescriptorProto + // 283 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xca, 0xcb, 0x4f, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x52, 0xba, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, 0xe9, 0xf9, 0xe9, 0xf9, 0xfa, 0x60, 0xe1, 0xa4, 0xd2, - 0x34, 0x30, 0x0f, 0xcc, 0x01, 0xb3, 0x20, 0xba, 0x94, 0x6c, 0xb9, 0xb8, 0x9d, 0xf3, 0x73, 0x72, + 0x34, 0x30, 0x0f, 0xcc, 0x01, 0xb3, 0x20, 0xba, 0x94, 0x72, 0xb9, 0xb8, 0x9d, 0xf3, 0x73, 0x72, 0x12, 0x0b, 0x8a, 0x53, 0x53, 0x9c, 0xf2, 0x84, 0xf4, 0xb8, 0xf8, 0x5d, 0xf3, 0x92, 0xf3, 0x53, 0x52, 0x53, 0x9c, 0x33, 0x32, 0x73, 0x52, 0x8a, 0x52, 0xf3, 0x24, 0x18, 0x15, 0x98, 0x35, 0x78, - 0x9c, 0x58, 0x4e, 0xdc, 0x93, 0x67, 0x0c, 0x42, 0x97, 0x54, 0x72, 0x46, 0xd2, 0xee, 0x9a, 0x27, - 0x24, 0xc0, 0xc5, 0xec, 0x9d, 0x5a, 0x29, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x13, 0x04, 0x62, 0x0a, - 0x29, 0x71, 0xf1, 0x20, 0xeb, 0x91, 0x60, 0x02, 0x4b, 0xa1, 0x88, 0x29, 0x99, 0x22, 0x19, 0xe2, - 0x83, 0xcd, 0x10, 0x11, 0x2e, 0xd6, 0xb0, 0xc4, 0x9c, 0xd2, 0x54, 0xa8, 0x6e, 0x08, 0xc7, 0xc9, - 0xee, 0xc2, 0x43, 0x39, 0x86, 0x1b, 0x0f, 0xe5, 0x18, 0x3e, 0x3c, 0x94, 0x63, 0x6c, 0x78, 0x24, - 0xc7, 0xb8, 0xe2, 0x91, 0x1c, 0xe3, 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0xde, 0x78, - 0x24, 0xc7, 0xf8, 0xe0, 0x91, 0x1c, 0xe3, 0x8b, 0x47, 0x72, 0x0c, 0x1f, 0x1e, 0xc9, 0x31, 0x4e, - 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c, 0x51, 0x2c, 0x25, 0x45, - 0x99, 0xa9, 0x49, 0x6c, 0xe0, 0x10, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x2c, 0xf8, 0x0f, - 0x28, 0x45, 0x01, 0x00, 0x00, + 0x9c, 0x58, 0x4e, 0xdc, 0x93, 0x67, 0x0c, 0x42, 0x97, 0x04, 0xa9, 0x87, 0xb1, 0xc3, 0x52, 0x8b, + 0x8a, 0x33, 0xf3, 0xf3, 0x24, 0x98, 0x14, 0x18, 0x11, 0xea, 0xd1, 0x24, 0x95, 0xd2, 0x91, 0xac, + 0x73, 0xcd, 0x13, 0x12, 0xe0, 0x62, 0xf6, 0x4e, 0xad, 0x94, 0x60, 0x04, 0x69, 0x09, 0x02, 0x31, + 0x85, 0x94, 0xb8, 0x78, 0x90, 0xed, 0x80, 0x98, 0x16, 0x84, 0x22, 0x06, 0x52, 0x03, 0x66, 0xc0, + 0x6c, 0x64, 0x56, 0x60, 0xd4, 0xe0, 0x0d, 0x42, 0x11, 0x53, 0xf2, 0x47, 0xb2, 0xc8, 0x07, 0x9b, + 0x45, 0x22, 0x5c, 0xac, 0x61, 0x89, 0x39, 0xa5, 0xa9, 0x50, 0x1b, 0x20, 0x1c, 0x21, 0x09, 0x2e, + 0x76, 0x54, 0x53, 0x61, 0x5c, 0x27, 0xbb, 0x0b, 0x0f, 0xe5, 0x18, 0x6e, 0x3c, 0x94, 0x63, 0xf8, + 0xf0, 0x50, 0x8e, 0xb1, 0xe1, 0x91, 0x1c, 0xe3, 0x8a, 0x47, 0x72, 0x8c, 0x27, 0x1e, 0xc9, 0x31, + 0x5e, 0x78, 0x24, 0xc7, 0x78, 0xe3, 0x91, 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x2f, 0x1e, 0xc9, + 0x31, 0x7c, 0x78, 0x24, 0xc7, 0x38, 0xe1, 0xb1, 0x1c, 0xc3, 0x85, 0xc7, 0x72, 0x0c, 0x37, 0x1e, + 0xcb, 0x31, 0x44, 0xb1, 0x94, 0x14, 0x65, 0xa6, 0x26, 0xb1, 0x81, 0xc3, 0xdb, 0x18, 0x10, 0x00, + 0x00, 0xff, 0xff, 0x62, 0x19, 0x3f, 0xae, 0xb3, 0x01, 0x00, 0x00, } func (this *CollapsedBn) Equal(that interface{}) bool { @@ -214,6 +240,9 @@ func (this *CollapsedBn) Equal(that interface{}) bool { return false } } + if !bytes.Equal(this.ChildrenVersion, that1.ChildrenVersion) { + return false + } return true } func (this *CollapsedEn) Equal(that interface{}) bool { @@ -241,6 +270,9 @@ func (this *CollapsedEn) Equal(that interface{}) bool { if !bytes.Equal(this.EncodedChild, that1.EncodedChild) { return false } + if this.ChildVersion != that1.ChildVersion { + return false + } return true } func (this *CollapsedLn) Equal(that interface{}) bool { @@ -268,15 +300,19 @@ func (this *CollapsedLn) Equal(that interface{}) bool { if !bytes.Equal(this.Value, that1.Value) { return false } + if this.Version != that1.Version { + return false + } return true } func (this *CollapsedBn) GoString() string { if this == nil { return "nil" } - s := make([]string, 0, 5) + s := make([]string, 0, 6) s = append(s, "&trie.CollapsedBn{") s = append(s, "EncodedChildren: "+fmt.Sprintf("%#v", this.EncodedChildren)+",\n") + s = append(s, "ChildrenVersion: "+fmt.Sprintf("%#v", this.ChildrenVersion)+",\n") s = append(s, "}") return strings.Join(s, "") } @@ -284,10 +320,11 @@ func (this *CollapsedEn) GoString() string { if this == nil { return "nil" } - s := make([]string, 0, 6) + s := make([]string, 0, 7) s = append(s, "&trie.CollapsedEn{") s = append(s, "Key: "+fmt.Sprintf("%#v", this.Key)+",\n") s = append(s, "EncodedChild: "+fmt.Sprintf("%#v", this.EncodedChild)+",\n") + s = append(s, "ChildVersion: "+fmt.Sprintf("%#v", this.ChildVersion)+",\n") s = append(s, "}") return strings.Join(s, "") } @@ -295,10 +332,11 @@ func (this *CollapsedLn) GoString() string { if this == nil { return "nil" } - s := make([]string, 0, 6) + s := make([]string, 0, 7) s = append(s, "&trie.CollapsedLn{") s = append(s, "Key: "+fmt.Sprintf("%#v", this.Key)+",\n") s = append(s, "Value: "+fmt.Sprintf("%#v", this.Value)+",\n") + s = append(s, "Version: "+fmt.Sprintf("%#v", this.Version)+",\n") s = append(s, "}") return strings.Join(s, "") } @@ -330,6 +368,13 @@ func (m *CollapsedBn) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.ChildrenVersion) > 0 { + i -= len(m.ChildrenVersion) + copy(dAtA[i:], m.ChildrenVersion) + i = encodeVarintNode(dAtA, i, uint64(len(m.ChildrenVersion))) + i-- + dAtA[i] = 0x12 + } if len(m.EncodedChildren) > 0 { for iNdEx := len(m.EncodedChildren) - 1; iNdEx >= 0; iNdEx-- { i -= len(m.EncodedChildren[iNdEx]) @@ -362,6 +407,11 @@ func (m *CollapsedEn) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if m.ChildVersion != 0 { + i = encodeVarintNode(dAtA, i, uint64(m.ChildVersion)) + i-- + dAtA[i] = 0x18 + } if len(m.EncodedChild) > 0 { i -= len(m.EncodedChild) copy(dAtA[i:], m.EncodedChild) @@ -399,6 +449,11 @@ func (m *CollapsedLn) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if m.Version != 0 { + i = encodeVarintNode(dAtA, i, uint64(m.Version)) + i-- + dAtA[i] = 0x18 + } if len(m.Value) > 0 { i -= len(m.Value) copy(dAtA[i:], m.Value) @@ -439,6 +494,10 @@ func (m *CollapsedBn) Size() (n int) { n += 1 + l + sovNode(uint64(l)) } } + l = len(m.ChildrenVersion) + if l > 0 { + n += 1 + l + sovNode(uint64(l)) + } return n } @@ -456,6 +515,9 @@ func (m *CollapsedEn) Size() (n int) { if l > 0 { n += 1 + l + sovNode(uint64(l)) } + if m.ChildVersion != 0 { + n += 1 + sovNode(uint64(m.ChildVersion)) + } return n } @@ -473,6 +535,9 @@ func (m *CollapsedLn) Size() (n int) { if l > 0 { n += 1 + l + sovNode(uint64(l)) } + if m.Version != 0 { + n += 1 + sovNode(uint64(m.Version)) + } return n } @@ -488,6 +553,7 @@ func (this *CollapsedBn) String() string { } s := strings.Join([]string{`&CollapsedBn{`, `EncodedChildren:` + fmt.Sprintf("%v", this.EncodedChildren) + `,`, + `ChildrenVersion:` + fmt.Sprintf("%v", this.ChildrenVersion) + `,`, `}`, }, "") return s @@ -499,6 +565,7 @@ func (this *CollapsedEn) String() string { s := strings.Join([]string{`&CollapsedEn{`, `Key:` + fmt.Sprintf("%v", this.Key) + `,`, `EncodedChild:` + fmt.Sprintf("%v", this.EncodedChild) + `,`, + `ChildVersion:` + fmt.Sprintf("%v", this.ChildVersion) + `,`, `}`, }, "") return s @@ -510,6 +577,7 @@ func (this *CollapsedLn) String() string { s := strings.Join([]string{`&CollapsedLn{`, `Key:` + fmt.Sprintf("%v", this.Key) + `,`, `Value:` + fmt.Sprintf("%v", this.Value) + `,`, + `Version:` + fmt.Sprintf("%v", this.Version) + `,`, `}`, }, "") return s @@ -583,6 +651,40 @@ func (m *CollapsedBn) Unmarshal(dAtA []byte) error { m.EncodedChildren = append(m.EncodedChildren, make([]byte, postIndex-iNdEx)) copy(m.EncodedChildren[len(m.EncodedChildren)-1], dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ChildrenVersion", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNode + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthNode + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthNode + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ChildrenVersion = append(m.ChildrenVersion[:0], dAtA[iNdEx:postIndex]...) + if m.ChildrenVersion == nil { + m.ChildrenVersion = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNode(dAtA[iNdEx:]) @@ -704,6 +806,25 @@ func (m *CollapsedEn) Unmarshal(dAtA []byte) error { m.EncodedChild = []byte{} } iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ChildVersion", wireType) + } + m.ChildVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNode + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ChildVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNode(dAtA[iNdEx:]) @@ -825,6 +946,25 @@ func (m *CollapsedLn) Unmarshal(dAtA []byte) error { m.Value = []byte{} } iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType) + } + m.Version = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNode + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Version |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNode(dAtA[iNdEx:]) diff --git a/trie/node.proto b/trie/node.proto index 98008fa4d3b..7ea370c3fac 100644 --- a/trie/node.proto +++ b/trie/node.proto @@ -9,14 +9,17 @@ import "github.com/gogo/protobuf/gogoproto/gogo.proto"; message CollapsedBn{ repeated bytes EncodedChildren = 1 [(gogoproto.nullable) = true]; + bytes ChildrenVersion = 2 [(gogoproto.nullable) = true]; } message CollapsedEn{ bytes Key = 1; bytes EncodedChild = 2; + uint32 ChildVersion = 3; } message CollapsedLn{ bytes Key = 1; bytes Value = 2; + uint32 Version = 3; } diff --git a/trie/node_test.go b/trie/node_test.go index 0b6e850ee63..d73bca88cfb 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" dataMock "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/trie/keyBuilder" logger "github.com/multiversx/mx-chain-logger-go" @@ -99,7 +100,7 @@ func TestNode_encodeNodeAndGetHashLeafNode(t *testing.T) { t.Parallel() marsh, hasher := getTestMarshalizerAndHasher() - ln, _ := newLeafNode([]byte("dog"), []byte("dog"), marsh, hasher) + ln, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), marsh, hasher) encNode, _ := marsh.Marshal(ln) encNode = append(encNode, leaf) @@ -525,7 +526,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesCollapsedTrie(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), tr.root.getHash(), keyBuilder.NewKeyBuilder()) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), tr.root.getHash(), keyBuilder.NewKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Nil(t, err) leaves := make(map[string][]byte) @@ -679,6 +680,491 @@ func TestTreatLogError(t *testing.T) { }) } +func TestNodesVersion_insertInLn(t *testing.T) { + t.Parallel() + + t.Run("insert in same leaf - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + ln, ok := tr.root.(*leafNode) + assert.True(t, ok) + version, _ := ln.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) + version, _ = ln.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in leaf - create new branch node", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + tr, _ = newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok = tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in leaf - create new extension", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + tr, _ = newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + +} + +func TestNodesVersion_insertInEn(t *testing.T) { + t.Parallel() + + t.Run("insert in same extension node - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("insert in extension node - create new branch - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.NotSpecified) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("insert in extension node - create new branch - do not change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in extension node - create new branch with following extension node - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.NotSpecified) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("insert in extension node - create new branch with following extension node - do not change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in extension node - create new extension and branch - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), core.NotSpecified) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("insert in extension node - create new extension and branch - do not change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), core.AutoBalanceEnabled) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) +} + +func TestNodesVersion_insertInBn(t *testing.T) { + t.Parallel() + + t.Run("insert in branch node on nil child - same version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in branch node on nil child - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.NotSpecified) + version, _ = bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("insert in branch node on existing child - same version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("insert in branch node on existing child - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) +} + +func TestNodesVersion_deleteFromEn(t *testing.T) { + t.Parallel() + + t.Run("new child is leaf node - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("aaa")) + ln, ok := tr.root.(*leafNode) + assert.True(t, ok) + version, _ = ln.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("new child is leaf node - same version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("baa")) + ln, ok := tr.root.(*leafNode) + assert.True(t, ok) + version, _ = ln.getVersion() + assert.Equal(t, core.NotSpecified, version) + }) + + t.Run("new child is extension node - same version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.Delete([]byte("zza")) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("new child is extension node - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), core.NotSpecified) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("zza")) + en, ok = tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("new child is branch node - same version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.Delete([]byte("aaa")) + bn, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("new child is branch node - change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), core.AutoBalanceEnabled) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ := en.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("aaa")) + bn, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) +} + +func TestNodesVersion_deleteFromBn(t *testing.T) { + t.Parallel() + + t.Run("delete leaf - branch does not reduce - bn should not change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + + _ = tr.Delete([]byte("aaa")) + bn, ok = tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("delete leaf - branch does not reduce - bn should change version", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("aaa")) + bn, ok = tr.root.(*branchNode) + assert.True(t, ok) + version, _ = bn.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("branch with branch child is reduced", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), core.NotSpecified) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("zzz")) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("branch with extension child is reduced", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), core.NotSpecified) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("zzz")) + en, ok := tr.root.(*extensionNode) + assert.True(t, ok) + version, _ = en.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) + + t.Run("branch with leaf child is reduced", func(t *testing.T) { + t.Parallel() + + tr, _ := newEmptyTrie() + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + version, _ := bn.getVersion() + assert.Equal(t, core.NotSpecified, version) + + _ = tr.Delete([]byte("aaa")) + ln, ok := tr.root.(*leafNode) + assert.True(t, ok) + version, _ = ln.getVersion() + assert.Equal(t, core.AutoBalanceEnabled, version) + }) +} + func Benchmark_ShouldStopIfContextDoneBlockingIfBusy(b *testing.B) { ctx := context.Background() b.ResetTimer() diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index fd1e41aca66..485b01bf199 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -13,8 +13,11 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var log = logger.GetOrCreate("trie") @@ -32,10 +35,12 @@ const rootDepthLevel = 0 type patriciaMerkleTrie struct { root node - trieStorage common.StorageManager - marshalizer marshal.Marshalizer - hasher hashing.Hasher - mutOperation sync.RWMutex + trieStorage common.StorageManager + marshalizer marshal.Marshalizer + hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler + trieNodeVersionVerifier core.TrieNodeVersionVerifier + mutOperation sync.RWMutex oldHashes [][]byte oldRoot []byte @@ -48,6 +53,7 @@ func NewTrie( trieStorage common.StorageManager, msh marshal.Marshalizer, hsh hashing.Hasher, + enableEpochsHandler common.EnableEpochsHandler, maxTrieLevelInMemory uint, ) (*patriciaMerkleTrie, error) { if check.IfNil(trieStorage) { @@ -59,19 +65,29 @@ func NewTrie( if check.IfNil(hsh) { return nil, ErrNilHasher } + if check.IfNil(enableEpochsHandler) { + return nil, errors.ErrNilEnableEpochsHandler + } if maxTrieLevelInMemory == 0 { return nil, ErrInvalidLevelValue } log.Trace("created new trie", "max trie level in memory", maxTrieLevelInMemory) + tnvv, err := core.NewTrieNodeVersionVerifier(enableEpochsHandler) + if err != nil { + return nil, err + } + return &patriciaMerkleTrie{ - trieStorage: trieStorage, - marshalizer: msh, - hasher: hsh, - oldHashes: make([][]byte, 0), - oldRoot: make([]byte, 0), - maxTrieLevelInMemory: maxTrieLevelInMemory, - chanClose: make(chan struct{}), + trieStorage: trieStorage, + marshalizer: msh, + hasher: hsh, + oldHashes: make([][]byte, 0), + oldRoot: make([]byte, 0), + maxTrieLevelInMemory: maxTrieLevelInMemory, + chanClose: make(chan struct{}), + enableEpochsHandler: enableEpochsHandler, + trieNodeVersionVerifier: tnvv, }, nil } @@ -102,19 +118,39 @@ func (tr *patriciaMerkleTrie) Update(key, value []byte) error { tr.mutOperation.Lock() defer tr.mutOperation.Unlock() - log.Trace("update trie", "key", hex.EncodeToString(key), "val", hex.EncodeToString(value)) + log.Trace("update trie", + "key", hex.EncodeToString(key), + "val", hex.EncodeToString(value), + ) - hexKey := keyBytesToHex(key) - newLn, err := newLeafNode(hexKey, value, tr.marshalizer, tr.hasher) - if err != nil { - return err - } + return tr.update(key, value, core.NotSpecified) +} - var newRoot node - var oldHashes [][]byte +// UpdateWithVersion does the same thing as Update, but the new leaf that is created will be of the specified version +func (tr *patriciaMerkleTrie) UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error { + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + + log.Trace("update trie with version", + "key", hex.EncodeToString(key), + "val", hex.EncodeToString(value), + "version", version, + ) + + return tr.update(key, value, version) +} + +func (tr *patriciaMerkleTrie) update(key []byte, value []byte, version core.TrieNodeVersion) error { + hexKey := keyBytesToHex(key) if len(value) != 0 { + newData := core.TrieData{ + Key: hexKey, + Value: value, + Version: version, + } + if tr.root == nil { - newRoot, err = newLeafNode(hexKey, value, tr.marshalizer, tr.hasher) + newRoot, err := newLeafNode(newData, tr.marshalizer, tr.hasher) if err != nil { return err } @@ -127,7 +163,7 @@ func (tr *patriciaMerkleTrie) Update(key, value []byte) error { tr.oldRoot = tr.root.getHash() } - newRoot, oldHashes, err = tr.root.insert(newLn, tr.trieStorage) + newRoot, oldHashes, err := tr.root.insert(newData, tr.trieStorage) if err != nil { return err } @@ -141,22 +177,7 @@ func (tr *patriciaMerkleTrie) Update(key, value []byte) error { logArrayWithTrace("oldHashes after insert", "hash", oldHashes) } else { - if tr.root == nil { - return nil - } - - if !tr.root.isDirty() { - tr.oldRoot = tr.root.getHash() - } - - _, newRoot, oldHashes, err = tr.root.delete(hexKey, tr.trieStorage) - if err != nil { - return err - } - tr.root = newRoot - tr.oldHashes = append(tr.oldHashes, oldHashes...) - - logArrayWithTrace("oldHashes after delete", "hash", oldHashes) + return tr.delete(hexKey) } return nil @@ -168,6 +189,10 @@ func (tr *patriciaMerkleTrie) Delete(key []byte) error { defer tr.mutOperation.Unlock() hexKey := keyBytesToHex(key) + return tr.delete(hexKey) +} + +func (tr *patriciaMerkleTrie) delete(hexKey []byte) error { if tr.root == nil { return nil } @@ -182,6 +207,7 @@ func (tr *patriciaMerkleTrie) Delete(key []byte) error { } tr.root = newRoot tr.oldHashes = append(tr.oldHashes, oldHashes...) + logArrayWithTrace("oldHashes after delete", "hash", oldHashes) return nil } @@ -272,6 +298,7 @@ func (tr *patriciaMerkleTrie) recreate(root []byte, tsm common.StorageManager) ( tr.trieStorage, tr.marshalizer, tr.hasher, + tr.enableEpochsHandler, tr.maxTrieLevelInMemory, ) } @@ -352,6 +379,7 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage tsm, tr.marshalizer, tr.hasher, + tr.enableEpochsHandler, tr.maxTrieLevelInMemory, ) if err != nil { @@ -429,6 +457,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, + trieLeafParser common.TrieLeafParser, ) error { if leavesChannels == nil { return ErrNilTrieIteratorChannels @@ -439,6 +468,12 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( if leavesChannels.ErrChan == nil { return ErrNilTrieIteratorErrChannel } + if check.IfNil(keyBuilder) { + return ErrNilKeyBuilder + } + if check.IfNil(trieLeafParser) { + return ErrNilTrieLeafParser + } newTrie, err := tr.recreate(rootHash, tr.trieStorage) if err != nil { @@ -459,6 +494,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( err = newTrie.root.getAllLeavesOnChannel( leavesChannels.LeavesChan, keyBuilder, + trieLeafParser, tr.trieStorage, tr.marshalizer, tr.chanClose, @@ -556,9 +592,21 @@ func (tr *patriciaMerkleTrie) GetProof(key []byte) ([][]byte, []byte, error) { // VerifyProof verifies the given Merkle proof func (tr *patriciaMerkleTrie) VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error) { - tr.mutOperation.Lock() - defer tr.mutOperation.Unlock() + tr.mutOperation.RLock() + defer tr.mutOperation.RUnlock() + + ok, err := tr.verifyProof(rootHash, tr.hasher.Compute(string(key)), proof) + if err != nil { + return false, err + } + if ok { + return true, nil + } + return tr.verifyProof(rootHash, key, proof) +} + +func (tr *patriciaMerkleTrie) verifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error) { wantHash := rootHash key = keyBytesToHex(key) for _, encodedNode := range proof { @@ -600,7 +648,7 @@ func (tr *patriciaMerkleTrie) GetOldRoot() []byte { } // GetTrieStats will collect and return the statistics for the given rootHash -func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (*statistics.TrieStatsDTO, error) { +func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (common.TrieStatisticsHandler, error) { newTrie, err := tr.recreate(rootHash, tr.trieStorage) if err != nil { return nil, err @@ -613,7 +661,67 @@ func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (*st } ts.AddAccountInfo(address, rootHash) - return ts.GetTrieStats(), nil + return ts, nil +} + +// CollectLeavesForMigration will collect trie leaves that need to be migrated. The leaves are collected in the trieMigrator. +// The traversing of the trie is done in a DFS manner, and it will stop when the gas runs out (this will be signaled by the trieMigrator). +func (tr *patriciaMerkleTrie) CollectLeavesForMigration(args vmcommon.ArgsMigrateDataTrieLeaves) error { + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + + if check.IfNil(tr.root) { + return nil + } + if check.IfNil(args.TrieMigrator) { + return errors.ErrNilTrieMigrator + } + + err := tr.checkIfMigrationPossible(args) + if err != nil { + return err + } + + _, err = tr.root.collectLeavesForMigration(args, tr.trieStorage, keyBuilder.NewKeyBuilder()) + if err != nil { + return err + } + + return nil +} + +func (tr *patriciaMerkleTrie) checkIfMigrationPossible(args vmcommon.ArgsMigrateDataTrieLeaves) error { + if !tr.trieNodeVersionVerifier.IsValidVersion(args.NewVersion) { + return fmt.Errorf("%w: newVersion %v", errors.ErrInvalidTrieNodeVersion, args.NewVersion) + } + + if !tr.trieNodeVersionVerifier.IsValidVersion(args.OldVersion) { + return fmt.Errorf("%w: oldVersion %v", errors.ErrInvalidTrieNodeVersion, args.OldVersion) + } + + if args.NewVersion == core.NotSpecified && args.OldVersion == core.AutoBalanceEnabled { + return fmt.Errorf("%w: cannot migrate from %v to %v", errors.ErrInvalidTrieNodeVersion, core.AutoBalanceEnabled, core.NotSpecified) + } + + return nil +} + +// IsMigratedToLatestVersion returns true if the trie is migrated to the latest version +func (tr *patriciaMerkleTrie) IsMigratedToLatestVersion() (bool, error) { + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + + if check.IfNil(tr.root) { + return true, nil + } + + version, err := tr.root.getVersion() + if err != nil { + return false, err + } + + versionForNewlyAddedData := core.GetVersionForNewData(tr.enableEpochsHandler) + return version == versionForNewlyAddedData, nil } // Close stops all the active goroutines started by the trie diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 384d6891b8f..900d1b66002 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand" "strconv" + "strings" "sync" "testing" "time" @@ -18,11 +19,15 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" + errorsCommon "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/state/parsers" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/mock" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -35,12 +40,19 @@ func emptyTrie() common.Trie { return tr } -func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, uint) { +func emptyTrieWithCustomEnableEpochsHandler(handler common.EnableEpochsHandler) common.Trie { + storage, marshaller, hasher, _, maxTrieLevelInMem := getDefaultTrieParameters() + + tr, _ := trie.NewTrie(storage, marshaller, hasher, handler, maxTrieLevelInMem) + return tr +} + +func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, uint) { args := trie.GetDefaultTrieStorageManagerParameters() trieStorageManager, _ := trie.NewTrieStorageManager(args) maxTrieLevelInMemory := uint(1) - return trieStorageManager, args.Marshalizer, args.Hasher, maxTrieLevelInMemory + return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory } func initTrieMultipleValues(nr int) (common.Trie, [][]byte) { @@ -59,18 +71,22 @@ func initTrieMultipleValues(nr int) (common.Trie, [][]byte) { func initTrie() common.Trie { tr := emptyTrie() + addDefaultDataToTrie(tr) + + return tr +} + +func addDefaultDataToTrie(tr common.Trie) { _ = tr.Update([]byte("doe"), []byte("reindeer")) _ = tr.Update([]byte("dog"), []byte("puppy")) _ = tr.Update([]byte("ddog"), []byte("cat")) - - return tr } func TestNewTrieWithNilTrieStorage(t *testing.T) { t.Parallel() - _, marshalizer, hasher, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(nil, marshalizer, hasher, maxTrieLevelInMemory) + _, marshalizer, hasher, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() + tr, err := trie.NewTrie(nil, marshalizer, hasher, enableEpochsHandler, maxTrieLevelInMemory) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilTrieStorage, err) @@ -79,8 +95,8 @@ func TestNewTrieWithNilTrieStorage(t *testing.T) { func TestNewTrieWithNilMarshalizer(t *testing.T) { t.Parallel() - trieStorage, _, hasher, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, nil, hasher, maxTrieLevelInMemory) + trieStorage, _, hasher, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, nil, hasher, enableEpochsHandler, maxTrieLevelInMemory) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilMarshalizer, err) @@ -89,18 +105,28 @@ func TestNewTrieWithNilMarshalizer(t *testing.T) { func TestNewTrieWithNilHasher(t *testing.T) { t.Parallel() - trieStorage, marshalizer, _, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, marshalizer, nil, maxTrieLevelInMemory) + trieStorage, marshalizer, _, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, marshalizer, nil, enableEpochsHandler, maxTrieLevelInMemory) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilHasher, err) } +func TestNewTrieWithNilEnableEpochsHandler(t *testing.T) { + t.Parallel() + + trieStorage, marshalizer, hasher, _, maxTrieLevelInMemory := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, nil, maxTrieLevelInMemory) + + assert.Nil(t, tr) + assert.Equal(t, errorsCommon.ErrNilEnableEpochsHandler, err) +} + func TestNewTrieWithInvalidMaxTrieLevelInMemory(t *testing.T) { t.Parallel() - trieStorage, marshalizer, hasher, _ := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, 0) + trieStorage, marshalizer, hasher, enableEpochsHandler, _ := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, enableEpochsHandler, 0) assert.Nil(t, tr) assert.Equal(t, trie.ErrInvalidLevelValue, err) @@ -347,6 +373,18 @@ func TestPatriciaMerkleTree_DeleteAfterCommit(t *testing.T) { assert.Equal(t, root2, root1) } +func TestPatriciaMerkleTree_DeleteNotPresent(t *testing.T) { + t.Parallel() + + tr := initTrie() + + err := tr.Commit() + assert.Nil(t, err) + + err = tr.Delete([]byte("adog")) + assert.Nil(t, err) +} + func TestPatriciaMerkleTrie_Recreate(t *testing.T) { t.Parallel() @@ -535,7 +573,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { t.Parallel() tr := emptyTrie() - err := tr.GetAllLeavesOnChannel(nil, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) + err := tr.GetAllLeavesOnChannel(nil, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Equal(t, trie.ErrNilTrieIteratorChannels, err) }) @@ -548,7 +586,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { LeavesChan: nil, ErrChan: errChan.NewErrChanWrapper(), } - err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) + err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Equal(t, trie.ErrNilTrieIteratorLeavesChannel, err) }) @@ -561,10 +599,36 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: nil, } - err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) + err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Equal(t, trie.ErrNilTrieIteratorErrChannel, err) }) + t.Run("nil keyBuilder", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + + iteratorChannels := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), + ErrChan: errChan.NewErrChanWrapper(), + } + err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, nil, parsers.NewMainTrieLeafParser()) + assert.Equal(t, trie.ErrNilKeyBuilder, err) + }) + + t.Run("nil trieLeafParser", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + + iteratorChannels := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), + ErrChan: errChan.NewErrChanWrapper(), + } + err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder(), nil) + assert.Equal(t, trie.ErrNilTrieLeafParser, err) + }) + t.Run("empty trie", func(t *testing.T) { t.Parallel() @@ -574,7 +638,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.NotNil(t, leavesChannel) @@ -606,7 +670,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { return keyBuilderStub } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilderStub) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilderStub, parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.NotNil(t, leavesChannel) @@ -648,7 +712,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { return keyBuilderStub } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilderStub) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilderStub, parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.NotNil(t, leavesChannel) @@ -681,7 +745,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) + err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewKeyBuilder(), parsers.NewMainTrieLeafParser()) assert.Nil(t, err) assert.NotNil(t, leavesChannel) @@ -886,14 +950,11 @@ func TestPatriciaMerkleTrie_GetTrieStats(t *testing.T) { stats, err := ts.GetTrieStats(address, rootHash) assert.Nil(t, err) - assert.Equal(t, rootHash, stats.RootHash) - assert.Equal(t, address, stats.Address) - - assert.Equal(t, uint64(2), stats.NumBranchNodes) - assert.Equal(t, uint64(1), stats.NumExtensionNodes) - assert.Equal(t, uint64(3), stats.NumLeafNodes) - assert.Equal(t, uint64(6), stats.TotalNumNodes) - assert.Equal(t, uint32(3), stats.MaxTrieDepth) + assert.Equal(t, uint64(2), stats.GetNumBranchNodes()) + assert.Equal(t, uint64(1), stats.GetNumExtensionNodes()) + assert.Equal(t, uint64(3), stats.GetNumLeafNodes()) + assert.Equal(t, uint64(6), stats.GetTotalNumNodes()) + assert.Equal(t, uint32(3), stats.GetMaxTrieDepth()) } func TestPatriciaMerkleTrie_GetOldRoot(t *testing.T) { @@ -993,6 +1054,7 @@ func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) { context.Background(), initialRootHash, keyBuilder.NewKeyBuilder(), + parsers.NewMainTrieLeafParser(), ) assert.Nil(t, err) case 13: @@ -1036,7 +1098,7 @@ func TestPatriciaMerkleTrie_GetSerializedNodesClose(t *testing.T) { } trieStorageManager, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorageManager, args.Marshalizer, args.Hasher, 5) + tr, _ := trie.NewTrie(trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) numGoRoutines := 1000 wgStart := sync.WaitGroup{} wgStart.Add(numGoRoutines) @@ -1086,6 +1148,406 @@ func TestPatriciaMerkleTrie_GetSerializedNodesClose(t *testing.T) { } } +type dataTrie interface { + CollectLeavesForMigration(args vmcommon.ArgsMigrateDataTrieLeaves) error + UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error +} + +func TestPatriciaMerkleTrie_CollectLeavesForMigration(t *testing.T) { + t.Parallel() + + t.Run("nil root", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + assert.Fail(t, "should not have called this function") + return false + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := tr.(dataTrie).CollectLeavesForMigration(args) + assert.Nil(t, err) + }) + + t.Run("nil trie migrator", func(t *testing.T) { + t.Parallel() + + tr := initTrie().(dataTrie) + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: nil, + } + err := tr.CollectLeavesForMigration(args) + assert.Equal(t, errorsCommon.ErrNilTrieMigrator, err) + }) + + t.Run("data trie already migrated", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 1, numLoadsCalled) + }) + + t.Run("trie partially migrated", func(t *testing.T) { + t.Parallel() + + addLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + key := []byte("dog") + value := []byte("reindeer") + _ = dtr.UpdateWithVersion(key, value, core.NotSpecified) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + + dtm := &trieMock.DataTrieMigratorStub{ + AddLeafToMigrationQueueCalled: func(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) { + assert.Equal(t, core.AutoBalanceEnabled, newLeafVersion) + assert.Equal(t, key, leafData.Key) + assert.Equal(t, value, leafData.Value) + assert.Equal(t, core.NotSpecified, leafData.Version) + addLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 1, addLeafToMigrationQueueCalled) + }) + + t.Run("not enough gas to load the whole trie", func(t *testing.T) { + t.Parallel() + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + + dtr := tr.(dataTrie) + numLoads := 0 + numAddLeafToMigrationQueueCalled := 0 + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + if numLoads < 2 { + numLoads++ + return true + } + + numLoads++ + return false + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 3, numLoads) + assert.Equal(t, 1, numAddLeafToMigrationQueueCalled) + }) + + t.Run("not enough gas to migrate the whole trie", func(t *testing.T) { + t.Parallel() + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + dtr := tr.(dataTrie) + numLoads := 0 + numAddLeafToMigrationQueueCalled := 0 + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoads++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + if numAddLeafToMigrationQueueCalled < 1 { + numAddLeafToMigrationQueueCalled++ + return true, nil + } + + numAddLeafToMigrationQueueCalled++ + return false, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 5, numLoads) + assert.Equal(t, 2, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate to non existent version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + dtr := initTrie().(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.TrieNodeVersion(100), + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.True(t, strings.Contains(err.Error(), errorsCommon.ErrInvalidTrieNodeVersion.Error())) + assert.Equal(t, 0, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate from non existent version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + dtr := initTrie().(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.TrieNodeVersion(100), + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.True(t, strings.Contains(err.Error(), errorsCommon.ErrInvalidTrieNodeVersion.Error())) + assert.Equal(t, 0, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate collapsed trie", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + _ = tr.Commit() + rootHash, _ := tr.RootHash() + collapsedTrie, _ := tr.Recreate(rootHash) + dtr := collapsedTrie.(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 6, numLoadsCalled) + assert.Equal(t, 3, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate all non migrated leaves", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.NotSpecified) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 2, numLoadsCalled) + assert.Equal(t, 1, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate to same version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + args := vmcommon.ArgsMigrateDataTrieLeaves{ + OldVersion: core.NotSpecified, + NewVersion: core.AutoBalanceEnabled, + TrieMigrator: dtm, + } + err := dtr.CollectLeavesForMigration(args) + assert.Nil(t, err) + assert.Equal(t, 1, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) +} + +func TestPatriciaMerkleTrie_IsMigrated(t *testing.T) { + t.Parallel() + + t.Run("nil root", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + isMigrated, err := tr.IsMigratedToLatestVersion() + assert.True(t, isMigrated) + assert.Nil(t, err) + }) + + t.Run("not migrated", func(t *testing.T) { + t.Parallel() + + tsm, marshaller, hasher, _, maxTrieInMem := getDefaultTrieParameters() + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, maxTrieInMem) + + _ = tr.Update([]byte("dog"), []byte("reindeer")) + isMigrated, err := tr.IsMigratedToLatestVersion() + assert.False(t, isMigrated) + assert.Nil(t, err) + }) + + t.Run("migrated", func(t *testing.T) { + t.Parallel() + + tsm, marshaller, hasher, _, maxTrieInMem := getDefaultTrieParameters() + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, maxTrieInMem) + + _ = tr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + isMigrated, err := tr.IsMigratedToLatestVersion() + assert.True(t, isMigrated) + assert.Nil(t, err) + }) +} + func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) { tr := emptyTrie() hsh := keccak.NewKeccak() diff --git a/trie/statistics/trieStatistics.go b/trie/statistics/trieStatistics.go index 7d56578cc66..85cc322cca7 100644 --- a/trie/statistics/trieStatistics.go +++ b/trie/statistics/trieStatistics.go @@ -3,8 +3,11 @@ package statistics import ( "encoding/hex" "fmt" + "sort" + "sync" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" ) type trieStatistics struct { @@ -15,6 +18,9 @@ type trieStatistics struct { branchNodes *nodesStatistics extensionNodes *nodesStatistics leafNodes *nodesStatistics + migrationStats map[core.TrieNodeVersion]uint64 + + mutex sync.RWMutex } type nodesStatistics struct { @@ -40,26 +46,40 @@ func NewTrieStatistics() *trieStatistics { nodesSize: 0, numNodes: 0, }, + migrationStats: make(map[core.TrieNodeVersion]uint64), } } // AddBranchNode will add the given level and size to the branch nodes statistics func (ts *trieStatistics) AddBranchNode(level int, size uint64) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + ts.collectNodeStatistics(level, size, ts.branchNodes) } // AddExtensionNode will add the given level and size to the extension nodes statistics func (ts *trieStatistics) AddExtensionNode(level int, size uint64) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + ts.collectNodeStatistics(level, size, ts.extensionNodes) } // AddLeafNode will add the given level and size to the leaf nodes statistics -func (ts *trieStatistics) AddLeafNode(level int, size uint64) { +func (ts *trieStatistics) AddLeafNode(level int, size uint64, version core.TrieNodeVersion) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + ts.collectNodeStatistics(level, size, ts.leafNodes) + ts.migrationStats[version]++ } // AddAccountInfo will add the address and rootHash to the collected statistics func (ts *trieStatistics) AddAccountInfo(address string, rootHash []byte) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + ts.address = address ts.rootHash = rootHash } @@ -73,55 +93,157 @@ func (ts *trieStatistics) collectNodeStatistics(level int, size uint64, nodeStat } } -// GetTrieStats returns a DTO that contains all the collected info about the trie -func (ts *trieStatistics) GetTrieStats() *TrieStatsDTO { - totalNodesSize := ts.branchNodes.nodesSize + ts.extensionNodes.nodesSize + ts.leafNodes.nodesSize - totalNumNodes := ts.branchNodes.numNodes + ts.extensionNodes.numNodes + ts.leafNodes.numNodes - - return &TrieStatsDTO{ - Address: ts.address, - RootHash: ts.rootHash, - TotalNodesSize: totalNodesSize, - TotalNumNodes: totalNumNodes, - MaxTrieDepth: ts.maxTrieDepth, - BranchNodesSize: ts.branchNodes.nodesSize, - NumBranchNodes: ts.branchNodes.numNodes, - ExtensionNodesSize: ts.extensionNodes.nodesSize, - NumExtensionNodes: ts.extensionNodes.numNodes, - LeafNodesSize: ts.leafNodes.nodesSize, - NumLeafNodes: ts.leafNodes.numNodes, +// GetTotalNodesSize will return the total size of all nodes +func (ts *trieStatistics) GetTotalNodesSize() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.getTotalNodesSize() +} + +func (ts *trieStatistics) getTotalNodesSize() uint64 { + return ts.branchNodes.nodesSize + ts.extensionNodes.nodesSize + ts.leafNodes.nodesSize +} + +// GetTotalNumNodes will return the total number of nodes +func (ts *trieStatistics) GetTotalNumNodes() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.getTotalNumNodes() +} + +func (ts *trieStatistics) getTotalNumNodes() uint64 { + return ts.branchNodes.numNodes + ts.extensionNodes.numNodes + ts.leafNodes.numNodes +} + +// GetMaxTrieDepth will return the maximum trie depth +func (ts *trieStatistics) GetMaxTrieDepth() uint32 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.maxTrieDepth +} + +// GetBranchNodesSize will return the size of all branch nodes +func (ts *trieStatistics) GetBranchNodesSize() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.branchNodes.nodesSize +} + +// GetNumBranchNodes will return the number of branch nodes +func (ts *trieStatistics) GetNumBranchNodes() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.branchNodes.numNodes +} + +// GetExtensionNodesSize will return the size of all extension nodes +func (ts *trieStatistics) GetExtensionNodesSize() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.extensionNodes.nodesSize +} + +// GetNumExtensionNodes will return the number of extension nodes +func (ts *trieStatistics) GetNumExtensionNodes() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.extensionNodes.numNodes +} + +// GetLeafNodesSize will return the size of all leaf nodes +func (ts *trieStatistics) GetLeafNodesSize() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.leafNodes.nodesSize +} + +// GetNumLeafNodes will return the number of leaf nodes +func (ts *trieStatistics) GetNumLeafNodes() uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + return ts.leafNodes.numNodes +} + +// GetLeavesMigrationStats will return the leaves migration statistics +func (ts *trieStatistics) GetLeavesMigrationStats() map[core.TrieNodeVersion]uint64 { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + + migrationStatsMap := make(map[core.TrieNodeVersion]uint64) + for version, numLeaves := range ts.migrationStats { + migrationStatsMap[version] = numLeaves } + + return migrationStatsMap } -// TrieStatsDTO holds the statistics for the trie -type TrieStatsDTO struct { - Address string - RootHash []byte - TotalNodesSize uint64 - TotalNumNodes uint64 - MaxTrieDepth uint32 +// MergeTriesStatistics will merge the given statistics with the current statistics +func (ts *trieStatistics) MergeTriesStatistics(statsToBeMerged common.TrieStatisticsHandler) { + ts.mutex.Lock() + defer ts.mutex.Unlock() + + if ts.maxTrieDepth < statsToBeMerged.GetMaxTrieDepth() { + ts.maxTrieDepth = statsToBeMerged.GetMaxTrieDepth() + } + + ts.branchNodes.numNodes += statsToBeMerged.GetNumBranchNodes() + ts.branchNodes.nodesSize += statsToBeMerged.GetBranchNodesSize() - BranchNodesSize uint64 - NumBranchNodes uint64 - ExtensionNodesSize uint64 - NumExtensionNodes uint64 - LeafNodesSize uint64 - NumLeafNodes uint64 + ts.extensionNodes.numNodes += statsToBeMerged.GetNumExtensionNodes() + ts.extensionNodes.nodesSize += statsToBeMerged.GetExtensionNodesSize() + + ts.leafNodes.numNodes += statsToBeMerged.GetNumLeafNodes() + ts.leafNodes.nodesSize += statsToBeMerged.GetLeafNodesSize() + + for version, numLeaves := range statsToBeMerged.GetLeavesMigrationStats() { + ts.migrationStats[version] += numLeaves + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (ts *trieStatistics) IsInterfaceNil() bool { + return ts == nil } // ToString returns the collected statistics as a string array -func (tsd *TrieStatsDTO) ToString() []string { +func (ts *trieStatistics) ToString() []string { + ts.mutex.RLock() + defer ts.mutex.RUnlock() + stats := make([]string, 0) - stats = append(stats, fmt.Sprintf("address %v,", tsd.Address)) - stats = append(stats, fmt.Sprintf("rootHash %v,", hex.EncodeToString(tsd.RootHash))) - stats = append(stats, fmt.Sprintf("total trie size = %v,", core.ConvertBytes(tsd.TotalNodesSize))) - stats = append(stats, fmt.Sprintf("num trie nodes = %v,", tsd.TotalNumNodes)) - stats = append(stats, fmt.Sprintf("max trie depth = %v,", tsd.MaxTrieDepth)) - stats = append(stats, fmt.Sprintf("branch nodes size %v,", core.ConvertBytes(tsd.BranchNodesSize))) - stats = append(stats, fmt.Sprintf("extension nodes size %v,", core.ConvertBytes(tsd.ExtensionNodesSize))) - stats = append(stats, fmt.Sprintf("leaf nodes size %v,", core.ConvertBytes(tsd.LeafNodesSize))) - stats = append(stats, fmt.Sprintf("num branches %v,", tsd.NumBranchNodes)) - stats = append(stats, fmt.Sprintf("num extensions %v,", tsd.NumExtensionNodes)) - stats = append(stats, fmt.Sprintf("num leaves %v", tsd.NumLeafNodes)) + stats = append(stats, fmt.Sprintf("address %v,", ts.address)) + stats = append(stats, fmt.Sprintf("rootHash %v,", hex.EncodeToString(ts.rootHash))) + stats = append(stats, fmt.Sprintf("total trie size = %v,", core.ConvertBytes(ts.getTotalNodesSize()))) + stats = append(stats, fmt.Sprintf("num trie nodes = %v,", ts.getTotalNumNodes())) + stats = append(stats, fmt.Sprintf("max trie depth = %v,", ts.maxTrieDepth)) + stats = append(stats, fmt.Sprintf("branch nodes size %v,", core.ConvertBytes(ts.branchNodes.nodesSize))) + stats = append(stats, fmt.Sprintf("extension nodes size %v,", core.ConvertBytes(ts.extensionNodes.nodesSize))) + stats = append(stats, fmt.Sprintf("leaf nodes size %v,", core.ConvertBytes(ts.leafNodes.nodesSize))) + stats = append(stats, fmt.Sprintf("num branches %v,", ts.branchNodes.numNodes)) + stats = append(stats, fmt.Sprintf("num extensions %v,", ts.extensionNodes.numNodes)) + stats = append(stats, fmt.Sprintf("num leaves %v", ts.leafNodes.numNodes)) + stats = append(stats, getMigrationStatsString(ts.migrationStats)...) + return stats +} + +func getMigrationStatsString(migrationStats map[core.TrieNodeVersion]uint64) []string { + stats := make([]string, 0) + for version, numNodes := range migrationStats { + stats = append(stats, fmt.Sprintf("num leaves with %s version = %v", version, numNodes)) + } + + sort.Slice(stats, func(i, j int) bool { + return stats[i] < stats[j] + }) + return stats } diff --git a/trie/statistics/trieStatisticsCollector.go b/trie/statistics/trieStatisticsCollector.go index af7118bc99e..2e3b444a9af 100644 --- a/trie/statistics/trieStatisticsCollector.go +++ b/trie/statistics/trieStatisticsCollector.go @@ -1,10 +1,15 @@ package statistics import ( + "fmt" + "sort" "strconv" "strings" + "sync" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -13,53 +18,41 @@ var log = logger.GetOrCreate("trieStatistics") const numTriesToPrint = 10 type trieStatisticsCollector struct { - numNodes uint64 - numDataTries uint64 - triesSize uint64 - numTotalLeaves uint64 - numTotalExtensions uint64 - numTotalBranches uint64 - totalSizeLeaves uint64 - totalSizeExtensions uint64 - totalSizeBranches uint64 - triesBySize []*TrieStatsDTO - triesByDepth []*TrieStatsDTO + trieStatsByType map[common.TrieType]common.TrieStatisticsHandler + triesBySize []common.TrieStatisticsHandler + triesByDepth []common.TrieStatisticsHandler + numTriesByType map[common.TrieType]uint64 + + mutex sync.RWMutex } // NewTrieStatisticsCollector creates a new instance of trieStatisticsCollector func NewTrieStatisticsCollector() *trieStatisticsCollector { return &trieStatisticsCollector{ - numNodes: 0, - numDataTries: 0, - triesSize: 0, - numTotalLeaves: 0, - numTotalExtensions: 0, - numTotalBranches: 0, - totalSizeLeaves: 0, - totalSizeExtensions: 0, - totalSizeBranches: 0, - triesBySize: make([]*TrieStatsDTO, numTriesToPrint), - triesByDepth: make([]*TrieStatsDTO, numTriesToPrint), + trieStatsByType: make(map[common.TrieType]common.TrieStatisticsHandler), + triesBySize: make([]common.TrieStatisticsHandler, numTriesToPrint), + triesByDepth: make([]common.TrieStatisticsHandler, numTriesToPrint), + numTriesByType: make(map[common.TrieType]uint64), } } // Add adds the given trie statistics to the statistics collector -func (tsc *trieStatisticsCollector) Add(trieStats *TrieStatsDTO) { - if trieStats == nil { +func (tsc *trieStatisticsCollector) Add(trieStats common.TrieStatisticsHandler, trieType common.TrieType) { + if check.IfNil(trieStats) { log.Warn("programming error, nil trie stats received") return } - tsc.numNodes += trieStats.TotalNumNodes - tsc.triesSize += trieStats.TotalNodesSize - tsc.numDataTries++ + tsc.mutex.Lock() + defer tsc.mutex.Unlock() + + _, ok := tsc.trieStatsByType[trieType] + if !ok { + tsc.trieStatsByType[trieType] = NewTrieStatistics() + } - tsc.numTotalBranches += trieStats.NumBranchNodes - tsc.numTotalExtensions += trieStats.NumExtensionNodes - tsc.numTotalLeaves += trieStats.NumLeafNodes - tsc.totalSizeBranches += trieStats.BranchNodesSize - tsc.totalSizeExtensions += trieStats.ExtensionNodesSize - tsc.totalSizeLeaves += trieStats.LeafNodesSize + tsc.trieStatsByType[trieType].MergeTriesStatistics(trieStats) + tsc.numTriesByType[trieType]++ insertInSortedArray(tsc.triesBySize, trieStats, isLessSize) insertInSortedArray(tsc.triesByDepth, trieStats, isLessDeep) @@ -67,33 +60,105 @@ func (tsc *trieStatisticsCollector) Add(trieStats *TrieStatsDTO) { // Print will print all the collected statistics func (tsc *trieStatisticsCollector) Print() { + tsc.mutex.RLock() + defer tsc.mutex.RUnlock() + triesBySize := " \n top " + strconv.Itoa(numTriesToPrint) + " tries by size \n" triesByDepth := " \n top " + strconv.Itoa(numTriesToPrint) + " tries by depth \n" + totalNumNodes := uint64(0) + totalStateSize := uint64(0) + numMainTrieLeaves := uint64(0) + maxDepthMainTrie := uint32(0) + + for trieType, stats := range tsc.trieStatsByType { + totalNumNodes += stats.GetTotalNumNodes() + totalStateSize += stats.GetTotalNodesSize() + + if trieType == common.MainTrie { + numMainTrieLeaves = stats.GetNumLeafNodes() + maxDepthMainTrie = stats.GetMaxTrieDepth() + } + } + log.Debug("tries statistics", - "num of nodes", tsc.numNodes, - "total size", core.ConvertBytes(tsc.triesSize), - "num tries", tsc.numDataTries, - "total num branches", tsc.numTotalBranches, - "total num extensions", tsc.numTotalExtensions, - "total num leaves", tsc.numTotalLeaves, - "total size branches", core.ConvertBytes(tsc.totalSizeBranches), - "total size extensions", core.ConvertBytes(tsc.totalSizeExtensions), - "total size leaves", core.ConvertBytes(tsc.totalSizeLeaves), + "num of nodes", totalNumNodes, + "total size", core.ConvertBytes(totalStateSize), + "num tries by type", getNumTriesByTypeString(tsc.numTriesByType), + "num main trie leaves", numMainTrieLeaves, + "max depth main trie", maxDepthMainTrie, + triesBySize, getOrderedTries(tsc.triesBySize), triesByDepth, getOrderedTries(tsc.triesByDepth), ) + + for trieType, trieStats := range tsc.trieStatsByType { + message := fmt.Sprintf("migration stats for %v", trieType) + log.Debug(message, "stats", getMigrationStatsString(trieStats.GetLeavesMigrationStats())) + } + + if log.GetLevel() == logger.LogTrace { + tsc.printDetailedTriesStatistics() + } +} + +func getNumTriesByTypeString(numTriesByTypeMap map[common.TrieType]uint64) string { + var numTriesByTypeMapString []string + for trieType, numTries := range numTriesByTypeMap { + numTriesByTypeMapString = append(numTriesByTypeMapString, fmt.Sprintf("%v: %v", trieType, numTries)) + } + + sort.Slice(numTriesByTypeMapString, func(i, j int) bool { + return numTriesByTypeMapString[i] < numTriesByTypeMapString[j] + }) + + return strings.Join(numTriesByTypeMapString, ", ") +} + +func (tsc *trieStatisticsCollector) printDetailedTriesStatistics() { + totalNumBranches := uint64(0) + totalNumExtensions := uint64(0) + totalNumLeaves := uint64(0) + totalSizeBranches := uint64(0) + totalSizeExtensions := uint64(0) + totalSizeLeaves := uint64(0) + + for _, stats := range tsc.trieStatsByType { + totalNumBranches += stats.GetNumBranchNodes() + totalNumExtensions += stats.GetNumExtensionNodes() + totalNumLeaves += stats.GetNumLeafNodes() + totalSizeBranches += stats.GetBranchNodesSize() + totalSizeExtensions += stats.GetExtensionNodesSize() + totalSizeLeaves += stats.GetLeafNodesSize() + } + + log.Trace("detailed tries statistics", + "total num branches", totalNumBranches, + "total num extensions", totalNumExtensions, + "total num leaves", totalNumLeaves, + "total size branches", core.ConvertBytes(totalSizeBranches), + "total size extensions", core.ConvertBytes(totalSizeExtensions), + "total size leaves", core.ConvertBytes(totalSizeLeaves), + ) } // GetNumNodes returns the number of nodes func (tsc *trieStatisticsCollector) GetNumNodes() uint64 { - return tsc.numNodes + tsc.mutex.RLock() + defer tsc.mutex.RUnlock() + + totalNumNodes := uint64(0) + for _, stats := range tsc.trieStatsByType { + totalNumNodes += stats.GetTotalNumNodes() + } + + return totalNumNodes } -func getOrderedTries(tries []*TrieStatsDTO) string { +func getOrderedTries(tries []common.TrieStatisticsHandler) string { triesStats := make([]string, 0) for i := 0; i < len(tries); i++ { - if tries[i] == nil { + if check.IfNil(tries[i]) { continue } triesStats = append(triesStats, strings.Join(tries[i].ToString(), " ")) @@ -102,18 +167,18 @@ func getOrderedTries(tries []*TrieStatsDTO) string { return strings.Join(triesStats, "\n") } -func isLessSize(a *TrieStatsDTO, b *TrieStatsDTO) bool { - return a.TotalNodesSize < b.TotalNodesSize +func isLessSize(a common.TrieStatisticsHandler, b common.TrieStatisticsHandler) bool { + return a.GetTotalNodesSize() < b.GetTotalNodesSize() } -func isLessDeep(a *TrieStatsDTO, b *TrieStatsDTO) bool { - return a.MaxTrieDepth < b.MaxTrieDepth +func isLessDeep(a common.TrieStatisticsHandler, b common.TrieStatisticsHandler) bool { + return a.GetMaxTrieDepth() < b.GetMaxTrieDepth() } func insertInSortedArray( - array []*TrieStatsDTO, - ts *TrieStatsDTO, - isLess func(*TrieStatsDTO, *TrieStatsDTO) bool, + array []common.TrieStatisticsHandler, + ts common.TrieStatisticsHandler, + isLess func(common.TrieStatisticsHandler, common.TrieStatisticsHandler) bool, ) { insertIndex := numTriesToPrint lastNilIndex := numTriesToPrint diff --git a/trie/statistics/trieStatisticsCollector_test.go b/trie/statistics/trieStatisticsCollector_test.go index e63af29fd9f..37d371b642c 100644 --- a/trie/statistics/trieStatisticsCollector_test.go +++ b/trie/statistics/trieStatisticsCollector_test.go @@ -1,10 +1,12 @@ package statistics import ( + "fmt" "math/rand" "sort" "testing" + "github.com/multiversx/mx-chain-go/common" "github.com/stretchr/testify/assert" ) @@ -13,11 +15,11 @@ func TestSnapshotStatistics_Add(t *testing.T) { tsc := NewTrieStatisticsCollector() - tsc.Add(nil) // coverage, early exit + tsc.Add(nil, common.MainTrie) // coverage, early exit numInserts := 100 for i := 0; i < numInserts; i++ { - tsc.Add(getTrieStatsDTO(rand.Intn(numInserts), uint64(rand.Intn(numInserts)))) + tsc.Add(getTrieStats(rand.Intn(numInserts), uint64(rand.Intn(numInserts))), common.DataTrie) isSortedBySize := sort.SliceIsSorted(tsc.triesBySize, func(a, b int) bool { if tsc.triesBySize[b] == nil && tsc.triesBySize[a] == nil { return false @@ -26,7 +28,7 @@ func TestSnapshotStatistics_Add(t *testing.T) { return false } - return tsc.triesBySize[b].TotalNodesSize < tsc.triesBySize[a].TotalNodesSize + return tsc.triesBySize[b].GetTotalNodesSize() < tsc.triesBySize[a].GetTotalNodesSize() }) isSortedByDepth := sort.SliceIsSorted(tsc.triesByDepth, func(a, b int) bool { @@ -37,7 +39,7 @@ func TestSnapshotStatistics_Add(t *testing.T) { return false } - return tsc.triesByDepth[b].MaxTrieDepth < tsc.triesByDepth[a].MaxTrieDepth + return tsc.triesByDepth[b].GetMaxTrieDepth() < tsc.triesByDepth[a].GetMaxTrieDepth() }) assert.True(t, isSortedBySize) @@ -50,9 +52,23 @@ func TestSnapshotStatistics_Add(t *testing.T) { } } -func getTrieStatsDTO(maxLevel int, size uint64) *TrieStatsDTO { +func getTrieStats(maxLevel int, size uint64) common.TrieStatisticsHandler { ts := NewTrieStatistics() ts.AddBranchNode(maxLevel, size) - return ts.GetTrieStats() + return ts +} + +func TestGetNumTriesByTypeString(t *testing.T) { + t.Parallel() + + numMainTries := 1 + numDataTries := 500 + numTriesByType := make(map[common.TrieType]uint64) + numTriesByType[common.MainTrie] = uint64(numMainTries) + numTriesByType[common.DataTrie] = uint64(numDataTries) + + numTriesByTypeString := getNumTriesByTypeString(numTriesByType) + expectedRes := fmt.Sprintf("%v: %v, %v: %v", common.DataTrie, numDataTries, common.MainTrie, numMainTries) + assert.Equal(t, expectedRes, numTriesByTypeString) } diff --git a/trie/statistics/trieStatistics_test.go b/trie/statistics/trieStatistics_test.go index d0870ced3dd..584bb4a6496 100644 --- a/trie/statistics/trieStatistics_test.go +++ b/trie/statistics/trieStatistics_test.go @@ -7,7 +7,6 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestTrieStatistics_AddBranchNode(t *testing.T) { @@ -49,13 +48,15 @@ func TestTrieStatistics_AddLeafNode(t *testing.T) { level := 2 size := uint64(15) - ts.AddLeafNode(level, size) - ts.AddLeafNode(level+1, size) - ts.AddLeafNode(level-1, size) + ts.AddLeafNode(level, size, 0) + ts.AddLeafNode(level+1, size, 0) + ts.AddLeafNode(level-1, size, 1) assert.Equal(t, uint64(3), ts.leafNodes.numNodes) assert.Equal(t, 3*size, ts.leafNodes.nodesSize) assert.Equal(t, 3*size, ts.leafNodes.nodesSize) assert.Equal(t, uint32(level+1), ts.maxTrieDepth) + assert.Equal(t, uint64(2), ts.migrationStats[0]) + assert.Equal(t, uint64(1), ts.migrationStats[1]) } func TestTrieStatistics_AddAccountInfo(t *testing.T) { @@ -97,53 +98,87 @@ func TestTrieStatistics_GetTrieStats(t *testing.T) { ts.AddExtensionNode(i, uint64(extensionSize)) } for i := 0; i < numLeaves; i++ { - ts.AddLeafNode(i, uint64(leafSize)) + ts.AddLeafNode(i, uint64(leafSize), 0) } - stats := ts.GetTrieStats() - assert.Equal(t, uint32(numBranches-1), stats.MaxTrieDepth) - assert.Equal(t, uint64(expectedBranchesSize), stats.BranchNodesSize) - assert.Equal(t, uint64(expectedExtensionsSize), stats.ExtensionNodesSize) - assert.Equal(t, uint64(expectedLeavesSize), stats.LeafNodesSize) - assert.Equal(t, uint64(expectedLeavesSize+expectedBranchesSize+expectedExtensionsSize), stats.TotalNodesSize) - assert.Equal(t, totalNumNodes, stats.TotalNumNodes) - assert.Equal(t, uint64(numBranches), stats.NumBranchNodes) - assert.Equal(t, uint64(numExtensions), stats.NumExtensionNodes) - assert.Equal(t, uint64(numLeaves), stats.NumLeafNodes) + assert.Equal(t, uint32(numBranches-1), ts.GetMaxTrieDepth()) + assert.Equal(t, uint64(expectedBranchesSize), ts.GetBranchNodesSize()) + assert.Equal(t, uint64(expectedExtensionsSize), ts.GetExtensionNodesSize()) + assert.Equal(t, uint64(expectedLeavesSize), ts.GetLeafNodesSize()) + assert.Equal(t, uint64(expectedLeavesSize+expectedBranchesSize+expectedExtensionsSize), ts.GetTotalNodesSize()) + assert.Equal(t, totalNumNodes, ts.GetTotalNumNodes()) + assert.Equal(t, uint64(numBranches), ts.GetNumBranchNodes()) + assert.Equal(t, uint64(numExtensions), ts.GetNumExtensionNodes()) + assert.Equal(t, uint64(numLeaves), ts.GetNumLeafNodes()) + assert.Equal(t, uint64(numLeaves), ts.GetLeavesMigrationStats()[0]) } -func TestTrieStatsDTO_ToString(t *testing.T) { +func TestTrieStatistics_MergeTriesStatistics(t *testing.T) { t.Parallel() - tsd := TrieStatsDTO{ - Address: "address", - RootHash: []byte("root hash"), - TotalNodesSize: 1, - TotalNumNodes: 1, - MaxTrieDepth: 1, - BranchNodesSize: 1, - NumBranchNodes: 1, - ExtensionNodesSize: 1, - NumExtensionNodes: 1, - LeafNodesSize: 1, - NumLeafNodes: 1, - } + leafSize := uint64(10) + branchSize := uint64(8) + extensionSize := uint64(5) - expectedLines := []string{ - fmt.Sprintf("address %v,", tsd.Address), - fmt.Sprintf("rootHash %v,", hex.EncodeToString(tsd.RootHash)), - fmt.Sprintf("total trie size = %v,", core.ConvertBytes(tsd.TotalNodesSize)), - fmt.Sprintf("num trie nodes = %v,", tsd.TotalNumNodes), - fmt.Sprintf("max trie depth = %v,", tsd.MaxTrieDepth), - fmt.Sprintf("branch nodes size %v,", core.ConvertBytes(tsd.BranchNodesSize)), - fmt.Sprintf("extension nodes size %v,", core.ConvertBytes(tsd.ExtensionNodesSize)), - fmt.Sprintf("leaf nodes size %v,", core.ConvertBytes(tsd.LeafNodesSize)), - fmt.Sprintf("num branches %v,", tsd.NumBranchNodes), - fmt.Sprintf("num extensions %v,", tsd.NumExtensionNodes), - fmt.Sprintf("num leaves %v", tsd.NumLeafNodes), - } - stringDTO := tsd.ToString() - for i, line := range stringDTO { - require.Equal(t, expectedLines[i], line) - } + ts := NewTrieStatistics() + newTs := NewTrieStatistics() + newTs.AddLeafNode(2, leafSize, 0) + newTs.AddLeafNode(3, leafSize, 1) + newTs.AddBranchNode(1, branchSize) + newTs.AddExtensionNode(3, extensionSize) + + ts.MergeTriesStatistics(newTs) + + assert.Equal(t, uint32(3), ts.GetMaxTrieDepth()) + assert.Equal(t, branchSize, ts.GetBranchNodesSize()) + assert.Equal(t, extensionSize, ts.GetExtensionNodesSize()) + assert.Equal(t, 2*leafSize, ts.GetLeafNodesSize()) + assert.Equal(t, branchSize+extensionSize+2*leafSize, ts.GetTotalNodesSize()) + assert.Equal(t, uint64(4), ts.GetTotalNumNodes()) + assert.Equal(t, uint64(1), ts.GetNumBranchNodes()) + assert.Equal(t, uint64(1), ts.GetNumExtensionNodes()) + assert.Equal(t, uint64(2), ts.GetNumLeafNodes()) + assert.Equal(t, uint64(1), ts.GetLeavesMigrationStats()[0]) + assert.Equal(t, uint64(1), ts.GetLeavesMigrationStats()[1]) + + newTs = NewTrieStatistics() + newTs.AddLeafNode(4, leafSize, 0) + newTs.AddLeafNode(1, leafSize, 1) + newTs.AddBranchNode(1, branchSize) + newTs.AddExtensionNode(3, extensionSize) + + ts.MergeTriesStatistics(newTs) + totalNodesSize := branchSize*2 + extensionSize*2 + leafSize*4 + + assert.Equal(t, uint32(4), ts.GetMaxTrieDepth()) + assert.Equal(t, branchSize*2, ts.GetBranchNodesSize()) + assert.Equal(t, extensionSize*2, ts.GetExtensionNodesSize()) + assert.Equal(t, leafSize*4, ts.GetLeafNodesSize()) + assert.Equal(t, totalNodesSize, ts.GetTotalNodesSize()) + assert.Equal(t, uint64(8), ts.GetTotalNumNodes()) + assert.Equal(t, uint64(2), ts.GetNumBranchNodes()) + assert.Equal(t, uint64(2), ts.GetNumExtensionNodes()) + assert.Equal(t, uint64(4), ts.GetNumLeafNodes()) + assert.Equal(t, uint64(2), ts.GetLeavesMigrationStats()[0]) + assert.Equal(t, uint64(2), ts.GetLeavesMigrationStats()[1]) + + address := "address" + rootHash := []byte("rootHash") + ts.AddAccountInfo(address, rootHash) + + trieStatsStrings := ts.ToString() + assert.Equal(t, 13, len(trieStatsStrings)) + assert.Equal(t, fmt.Sprintf("address %v,", address), trieStatsStrings[0]) + assert.Equal(t, fmt.Sprintf("rootHash %v,", hex.EncodeToString(rootHash)), trieStatsStrings[1]) + assert.Equal(t, fmt.Sprintf("total trie size = %v,", core.ConvertBytes(totalNodesSize)), trieStatsStrings[2]) + assert.Equal(t, fmt.Sprintf("num trie nodes = %v,", 8), trieStatsStrings[3]) + assert.Equal(t, fmt.Sprintf("max trie depth = %v,", 4), trieStatsStrings[4]) + assert.Equal(t, fmt.Sprintf("branch nodes size %v,", core.ConvertBytes(16)), trieStatsStrings[5]) + assert.Equal(t, fmt.Sprintf("extension nodes size %v,", core.ConvertBytes(10)), trieStatsStrings[6]) + assert.Equal(t, fmt.Sprintf("leaf nodes size %v,", core.ConvertBytes(40)), trieStatsStrings[7]) + assert.Equal(t, fmt.Sprintf("num branches %v,", 2), trieStatsStrings[8]) + assert.Equal(t, fmt.Sprintf("num extensions %v,", 2), trieStatsStrings[9]) + assert.Equal(t, fmt.Sprintf("num leaves %v", 4), trieStatsStrings[10]) + assert.Equal(t, fmt.Sprintf("num leaves with %s version = %v", core.AutoBalanceEnabledString, 2), trieStatsStrings[11]) + assert.Equal(t, fmt.Sprintf("num leaves with %s version = %v", core.NotSpecifiedString, 2), trieStatsStrings[12]) } diff --git a/trie/sync_test.go b/trie/sync_test.go index fcbf0ec04f7..ab5083eb85a 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" @@ -34,7 +35,7 @@ func createMockArgument(timeout time.Duration) ArgTrieSyncer { InterceptedNodes: testscommon.NewCacherMock(), DB: trieStorage, Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &testscommon.MarshalizerMock{}, + Marshalizer: &marshallerMock.MarshalizerMock{}, ShardId: 0, Topic: "topic", TrieSyncStatistics: statistics.NewTrieSyncStatistics(), diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go index 0fee55ade72..a8963058169 100644 --- a/trie/trieStorageManager.go +++ b/trie/trieStorageManager.go @@ -469,7 +469,15 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, } stats.AddAccountInfo(snapshotEntry.address, snapshotEntry.rootHash) - snapshotEntry.stats.AddTrieStats(stats.GetTrieStats()) + snapshotEntry.stats.AddTrieStats(stats, getTrieTypeFromAddress(snapshotEntry.address)) +} + +func getTrieTypeFromAddress(address string) common.TrieType { + if len(address) == 0 { + return common.MainTrie + } + + return common.DataTrie } func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEntry, msh marshal.Marshalizer, hsh hashing.Hasher, ctx context.Context, goRoutinesThrottler core.Throttler) { @@ -504,7 +512,7 @@ func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEnt } stats.AddAccountInfo(checkpointEntry.address, checkpointEntry.rootHash) - checkpointEntry.stats.AddTrieStats(stats.GetTrieStats()) + checkpointEntry.stats.AddTrieStats(stats, getTrieTypeFromAddress(checkpointEntry.address)) } func treatSnapshotError(err error, message string, rootHash []byte, mainTrieRootHash []byte) { diff --git a/update/errors.go b/update/errors.go index dc94a334854..938ae2020ee 100644 --- a/update/errors.go +++ b/update/errors.go @@ -292,3 +292,6 @@ var ErrNilAppStatusHandler = errors.New("nil app status handler") // ErrNilAddressConverter signals that a nil address converter was provided var ErrNilAddressConverter = errors.New("nil address converter") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler was provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/update/factory/accountDBSyncerContainerFactory.go b/update/factory/accountDBSyncerContainerFactory.go index ec1929c754f..58684996dca 100644 --- a/update/factory/accountDBSyncerContainerFactory.go +++ b/update/factory/accountDBSyncerContainerFactory.go @@ -35,6 +35,7 @@ type ArgsNewAccountsDBSyncersContainerFactory struct { TrieSyncerVersion int CheckNodesOnDisk bool AddressPubKeyConverter core.PubkeyConverter + EnableEpochsHandler common.EnableEpochsHandler } type accountDBSyncersContainerFactory struct { @@ -52,6 +53,7 @@ type accountDBSyncersContainerFactory struct { trieSyncerVersion int checkNodesOnDisk bool addressPubKeyConverter core.PubkeyConverter + enableEpochsHandler common.EnableEpochsHandler } // NewAccountsDBSContainerFactory creates a factory for trie syncers container @@ -87,6 +89,9 @@ func NewAccountsDBSContainerFactory(args ArgsNewAccountsDBSyncersContainerFactor if check.IfNil(args.AddressPubKeyConverter) { return nil, update.ErrNilPubKeyConverter } + if check.IfNil(args.EnableEpochsHandler) { + return nil, update.ErrNilEnableEpochsHandler + } t := &accountDBSyncersContainerFactory{ shardCoordinator: args.ShardCoordinator, @@ -102,6 +107,7 @@ func NewAccountsDBSContainerFactory(args ArgsNewAccountsDBSyncersContainerFactor trieSyncerVersion: args.TrieSyncerVersion, checkNodesOnDisk: args.CheckNodesOnDisk, addressPubKeyConverter: args.AddressPubKeyConverter, + enableEpochsHandler: args.EnableEpochsHandler, } return t, nil @@ -151,6 +157,7 @@ func (a *accountDBSyncersContainerFactory) createUserAccountsSyncer(shardId uint CheckNodesOnDisk: a.checkNodesOnDisk, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), AppStatusHandler: disabled.NewAppStatusHandler(), + EnableEpochsHandler: a.enableEpochsHandler, }, ShardId: shardId, Throttler: thr, @@ -180,6 +187,7 @@ func (a *accountDBSyncersContainerFactory) createValidatorAccountsSyncer(shardId CheckNodesOnDisk: a.checkNodesOnDisk, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), AppStatusHandler: disabled.NewAppStatusHandler(), + EnableEpochsHandler: a.enableEpochsHandler, }, } accountSyncer, err := syncer.NewValidatorAccountsSyncer(args) diff --git a/update/factory/dataTrieFactory.go b/update/factory/dataTrieFactory.go index db3fa4ea71b..bd3f7f178c3 100644 --- a/update/factory/dataTrieFactory.go +++ b/update/factory/dataTrieFactory.go @@ -29,6 +29,7 @@ type ArgsNewDataTrieFactory struct { Marshalizer marshal.Marshalizer Hasher hashing.Hasher ShardCoordinator sharding.Coordinator + EnableEpochsHandler common.EnableEpochsHandler MaxTrieLevelInMemory uint } @@ -37,6 +38,7 @@ type dataTrieFactory struct { trieStorage common.StorageManager marshalizer marshal.Marshalizer hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler maxTrieLevelInMemory uint } @@ -54,6 +56,9 @@ func NewDataTrieFactory(args ArgsNewDataTrieFactory) (*dataTrieFactory, error) { if check.IfNil(args.Hasher) { return nil, update.ErrNilHasher } + if check.IfNil(args.EnableEpochsHandler) { + return nil, update.ErrNilEnableEpochsHandler + } dbConfig := storageFactory.GetDBFromConfig(args.StorageConfig.DB) dbConfig.FilePath = path.Join(args.SyncFolder, args.StorageConfig.DB.FilePath) @@ -92,6 +97,7 @@ func NewDataTrieFactory(args ArgsNewDataTrieFactory) (*dataTrieFactory, error) { marshalizer: args.Marshalizer, hasher: args.Hasher, maxTrieLevelInMemory: args.MaxTrieLevelInMemory, + enableEpochsHandler: args.EnableEpochsHandler, } return d, nil @@ -127,7 +133,7 @@ func (d *dataTrieFactory) Create() (common.TriesHolder, error) { } func (d *dataTrieFactory) createAndAddOneTrie(shId uint32, accType genesis.Type, container common.TriesHolder) error { - dataTrie, err := trie.NewTrie(d.trieStorage, d.marshalizer, d.hasher, d.maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(d.trieStorage, d.marshalizer, d.hasher, d.enableEpochsHandler, d.maxTrieLevelInMemory) if err != nil { return err } diff --git a/update/factory/exportHandlerFactory.go b/update/factory/exportHandlerFactory.go index 98eb7bd8750..bb80be0101a 100644 --- a/update/factory/exportHandlerFactory.go +++ b/update/factory/exportHandlerFactory.go @@ -328,6 +328,7 @@ func (e *exportHandlerFactory) Create() (update.ExportHandler, error) { Hasher: e.CoreComponents.Hasher(), ShardCoordinator: e.shardCoordinator, MaxTrieLevelInMemory: e.maxTrieLevelInMemory, + EnableEpochsHandler: e.CoreComponents.EnableEpochsHandler(), } dataTriesContainerFactory, err := NewDataTrieFactory(argsDataTrieFactory) if err != nil { @@ -416,6 +417,7 @@ func (e *exportHandlerFactory) Create() (update.ExportHandler, error) { TrieSyncerVersion: e.trieSyncerVersion, CheckNodesOnDisk: e.checkNodesOnDisk, AddressPubKeyConverter: e.CoreComponents.AddressPubKeyConverter(), + EnableEpochsHandler: e.CoreComponents.EnableEpochsHandler(), } accountsDBSyncerFactory, err := NewAccountsDBSContainerFactory(argsAccountsSyncers) if err != nil { diff --git a/update/genesis/base.go b/update/genesis/base.go index dd3bed6d63b..a98eafd6651 100644 --- a/update/genesis/base.go +++ b/update/genesis/base.go @@ -11,6 +11,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/update" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -89,10 +92,21 @@ func NewObject(objType Type) (interface{}, error) { } // NewEmptyAccount returns a new account according to the given type -func NewEmptyAccount(accType Type, address []byte) (vmcommon.AccountHandler, error) { +func NewEmptyAccount( + accType Type, + address []byte, + hasher hashing.Hasher, + marshaller marshal.Marshalizer, + enableEpochsHandler common.EnableEpochsHandler, +) (vmcommon.AccountHandler, error) { switch accType { case UserAccount: - return state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: enableEpochsHandler, + } + return state.NewUserAccount(address, argsAccCreation) case ValidatorAccount: return state.NewPeerAccount(address) case DataTrie: diff --git a/update/genesis/export.go b/update/genesis/export.go index 51a6fc237b6..06bc16dc8a0 100644 --- a/update/genesis/export.go +++ b/update/genesis/export.go @@ -20,6 +20,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/update" logger "github.com/multiversx/mx-chain-logger-go" @@ -298,7 +299,13 @@ func (se *stateExport) exportTrie(key string, trie common.Trie) error { LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), ErrChan: errChan.NewErrChanWrapper(), } - err = trie.GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) + err = trie.GetAllLeavesOnChannel( + leavesChannels, + context.Background(), + rootHash, + keyBuilder.NewKeyBuilder(), + parsers.NewMainTrieLeafParser(), + ) if err != nil { return err } diff --git a/update/genesis/export_test.go b/update/genesis/export_test.go index d5587d031ae..8b2c3a17d40 100644 --- a/update/genesis/export_test.go +++ b/update/genesis/export_test.go @@ -288,7 +288,7 @@ func TestStateExport_ExportTrieShouldExportNodesSetupJson(t *testing.T) { RootCalled: func() ([]byte, error) { return []byte{}, nil }, - GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, _ common.TrieLeafParser) error { mm := &mock.MarshalizerMock{} valInfo := &state.ValidatorInfo{List: string(common.EligibleList)} pacB, _ := mm.Marshal(valInfo) @@ -337,7 +337,7 @@ func TestStateExport_ExportTrieShouldExportNodesSetupJson(t *testing.T) { RootCalled: func() ([]byte, error) { return []byte{}, nil }, - GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { + GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, _ common.TrieLeafParser) error { mm := &mock.MarshalizerMock{} valInfo := &state.ValidatorInfo{List: string(common.EligibleList)} pacB, _ := mm.Marshal(valInfo) diff --git a/update/genesis/import.go b/update/genesis/import.go index e740564c424..7f5d2166cee 100644 --- a/update/genesis/import.go +++ b/update/genesis/import.go @@ -17,6 +17,7 @@ import ( commonDisabled "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" @@ -37,6 +38,7 @@ type ArgsNewStateImport struct { TrieStorageManagers map[string]common.StorageManager HardforkStorer update.HardforkStorer AddressConverter core.PubkeyConverter + EnableEpochsHandler common.EnableEpochsHandler } type stateImport struct { @@ -56,6 +58,7 @@ type stateImport struct { storageConfig config.StorageConfig trieStorageManagers map[string]common.StorageManager addressConverter core.PubkeyConverter + enableEpochsHandler common.EnableEpochsHandler } // NewStateImport creates an importer which reads all the files for a new start @@ -75,6 +78,9 @@ func NewStateImport(args ArgsNewStateImport) (*stateImport, error) { if check.IfNil(args.AddressConverter) { return nil, update.ErrNilAddressConverter } + if check.IfNil(args.EnableEpochsHandler) { + return nil, errors.ErrNilEnableEpochsHandler + } st := &stateImport{ genesisHeaders: make(map[uint32]data.HeaderHandler), @@ -91,6 +97,7 @@ func NewStateImport(args ArgsNewStateImport) (*stateImport, error) { shardID: args.ShardID, hardforkStorer: args.HardforkStorer, addressConverter: args.AddressConverter, + enableEpochsHandler: args.EnableEpochsHandler, } return st, nil @@ -271,10 +278,20 @@ func (si *stateImport) importMiniBlocks(identifier string, keys [][]byte) error return nil } -func newAccountCreator(accType Type) (state.AccountFactory, error) { +func newAccountCreator( + accType Type, + hasher hashing.Hasher, + marshaller marshal.Marshalizer, + handler common.EnableEpochsHandler, +) (state.AccountFactory, error) { switch accType { case UserAccount: - return factory.NewAccountCreator(), nil + args := state.ArgsAccountCreation{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: handler, + } + return factory.NewAccountCreator(args) case ValidatorAccount: return factory.NewPeerAccountCreator(), nil } @@ -297,7 +314,7 @@ func (si *stateImport) getTrie(shardID uint32, accType Type) (common.Trie, error trieStorageManager = si.trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] } - trieForShard, err := trie.NewTrie(trieStorageManager, si.marshalizer, si.hasher, maxTrieLevelInMemory) + trieForShard, err := trie.NewTrie(trieStorageManager, si.marshalizer, si.hasher, si.enableEpochsHandler, maxTrieLevelInMemory) if err != nil { return nil, err } @@ -329,7 +346,7 @@ func (si *stateImport) importDataTrie(identifier string, shID uint32, keys [][]b return fmt.Errorf("%w wanted a roothash", update.ErrWrongTypeAssertion) } - dataTrie, err := trie.NewTrie(si.trieStorageManagers[dataRetriever.UserAccountsUnit.String()], si.marshalizer, si.hasher, maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(si.trieStorageManagers[dataRetriever.UserAccountsUnit.String()], si.marshalizer, si.hasher, si.enableEpochsHandler, maxTrieLevelInMemory) if err != nil { return err } @@ -359,7 +376,7 @@ func (si *stateImport) importDataTrie(identifier string, shID uint32, keys [][]b err = update.ErrKeyTypeMismatch break } - + // TODO this will not work for a partially migrated trie err = dataTrie.Update(address, value) if err != nil { break @@ -388,7 +405,7 @@ func (si *stateImport) importDataTrie(identifier string, shID uint32, keys [][]b } func (si *stateImport) getAccountsDB(accType Type, shardID uint32) (state.AccountsDBImporter, common.Trie, error) { - accountFactory, err := newAccountCreator(accType) + accountFactory, err := newAccountCreator(accType, si.hasher, si.marshalizer, si.enableEpochsHandler) if err != nil { return nil, nil, err } @@ -523,7 +540,7 @@ func (si *stateImport) unMarshalAndSaveAccount( accountsDB state.AccountsDBImporter, mainTrie common.Trie, ) error { - account, err := NewEmptyAccount(accType, address) + account, err := NewEmptyAccount(accType, address, si.hasher, si.marshalizer, si.enableEpochsHandler) if err != nil { return err } diff --git a/update/genesis/import_test.go b/update/genesis/import_test.go index 30dd5f95492..538a9eed617 100644 --- a/update/genesis/import_test.go +++ b/update/genesis/import_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/update" @@ -71,6 +72,7 @@ func TestNewStateImport(t *testing.T) { Hasher: &mock.HasherStub{}, TrieStorageManagers: trieStorageManagers, AddressConverter: &testscommon.PubkeyConverterMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }, exError: nil, }, @@ -98,6 +100,7 @@ func TestImportAll(t *testing.T) { ShardID: 0, StorageConfig: config.StorageConfig{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } importState, _ := NewStateImport(args) @@ -133,6 +136,7 @@ func TestStateImport_ImportUnFinishedMetaBlocksShouldWork(t *testing.T) { ShardID: 0, StorageConfig: config.StorageConfig{}, AddressConverter: &testscommon.PubkeyConverterMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } importState, _ := NewStateImport(args) diff --git a/vm/factory/systemSCFactory_test.go b/vm/factory/systemSCFactory_test.go index 5ea4e2b777e..a0aa8ecbb3c 100644 --- a/vm/factory/systemSCFactory_test.go +++ b/vm/factory/systemSCFactory_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" @@ -77,7 +78,7 @@ func createMockNewSystemScFactoryArgs() ArgsNewSystemSCFactory { }, AddressPubKeyConverter: &testscommon.PubkeyConverterMock{}, ShardCoordinator: &mock.ShardCoordinatorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/vm/gasCost.go b/vm/gasCost.go index 33254a1a204..c2263db787a 100644 --- a/vm/gasCost.go +++ b/vm/gasCost.go @@ -56,6 +56,8 @@ type BuiltInCost struct { ESDTNFTAddUri uint64 ESDTNFTUpdateAttributes uint64 ESDTNFTMultiTransfer uint64 + TrieLoadPerNode uint64 + TrieStorePerNode uint64 } // GasCost holds all the needed gas costs for system smart contracts diff --git a/vm/mock/blockChainHookStub.go b/vm/mock/blockChainHookStub.go index 0c326296025..c16ab610130 100644 --- a/vm/mock/blockChainHookStub.go +++ b/vm/mock/blockChainHookStub.go @@ -2,6 +2,9 @@ package mock import ( "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -74,7 +77,12 @@ func (b *BlockChainHookStub) GetUserAccount(address []byte) (vmcommon.UserAccoun return b.GetUserAccountCalled(address) } - return state.NewUserAccount(address) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + return state.NewUserAccount(address, argsAccCreation) } // GetShardOfAddress - diff --git a/vm/systemSmartContracts/defaults/gasMap.go b/vm/systemSmartContracts/defaults/gasMap.go index 98a1ce483d9..822b61b3651 100644 --- a/vm/systemSmartContracts/defaults/gasMap.go +++ b/vm/systemSmartContracts/defaults/gasMap.go @@ -50,6 +50,8 @@ func FillGasMapBuiltInCosts(value uint64) map[string]uint64 { gasMap["SetGuardian"] = value gasMap["GuardAccount"] = value gasMap["UnGuardAccount"] = value + gasMap["TrieLoadPerNode"] = value + gasMap["TrieStorePerNode"] = value return gasMap } diff --git a/vm/systemSmartContracts/delegationManager_test.go b/vm/systemSmartContracts/delegationManager_test.go index ed374a69b24..bfc565e3572 100644 --- a/vm/systemSmartContracts/delegationManager_test.go +++ b/vm/systemSmartContracts/delegationManager_test.go @@ -10,7 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/config" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -37,7 +37,7 @@ func createMockArgumentsForDelegationManager() ArgsNewDelegationManager { ConfigChangeAddress: configChangeAddress, GasCost: vm.GasCost{MetaChainSystemSCsCost: vm.MetaChainSystemSCsCost{ESDTIssue: 10}}, Marshalizer: &mock.MarshalizerMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsDelegationManagerFlagEnabledField: true, IsValidatorToDelegationFlagEnabledField: true, IsMultiClaimOnDelegationEnabledField: true, @@ -189,7 +189,7 @@ func TestDelegationManagerSystemSC_ExecuteWithDelegationManagerDisabled(t *testi args := createMockArgumentsForDelegationManager() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) dm, _ := NewDelegationManagerSystemSC(args) enableEpochsHandler.IsDelegationManagerFlagEnabledField = false @@ -680,7 +680,7 @@ func TestDelegationManagerSystemSC_checkValidatorToDelegationInput(t *testing.T) args.Eei = eei args.GasCost.MetaChainSystemSCsCost.ValidatorToDelegation = 100 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) d, _ := NewDelegationManagerSystemSC(args) vmInput := getDefaultVmInputForDelegationManager("createNewDelegationContract", [][]byte{maxDelegationCap, serviceFee}) @@ -722,7 +722,7 @@ func TestDelegationManagerSystemSC_MakeNewContractFromValidatorData(t *testing.T args.Eei = eei args.GasCost.MetaChainSystemSCsCost.ValidatorToDelegation = 100 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) d, _ := NewDelegationManagerSystemSC(args) vmInput := getDefaultVmInputForDelegationManager("makeNewContractFromValidatorData", [][]byte{maxDelegationCap, serviceFee}) _ = d.init(&vmcommon.ContractCallInput{VMInput: vmcommon.VMInput{CallValue: big.NewInt(0)}}) @@ -761,7 +761,7 @@ func TestDelegationManagerSystemSC_mergeValidatorToDelegationSameOwner(t *testin args.Eei = eei args.GasCost.MetaChainSystemSCsCost.ValidatorToDelegation = 100 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) d, _ := NewDelegationManagerSystemSC(args) vmInput := getDefaultVmInputForDelegationManager("mergeValidatorToDelegationSameOwner", [][]byte{maxDelegationCap, serviceFee}) _ = d.init(&vmcommon.ContractCallInput{VMInput: vmcommon.VMInput{CallValue: big.NewInt(0)}}) @@ -846,7 +846,7 @@ func TestDelegationManagerSystemSC_mergeValidatorToDelegationWithWhiteListInvali serviceFee := []byte{10} eei.returnMessage = "" vmInput := getDefaultVmInputForDelegationManager("mergeValidatorToDelegationWithWhitelist", [][]byte{maxDelegationCap, serviceFee}) - enableEpochsHandler, _ := d.enableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := d.enableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsValidatorToDelegationFlagEnabledField = false returnCode := d.Execute(vmInput) assert.Equal(t, vmcommon.UserError, returnCode) @@ -1096,7 +1096,7 @@ func TestDelegationManagerSystemSC_ClaimMultipleDelegationFails(t *testing.T) { createSystemSCContainer(eei), ) - enableHandlerStub := &testscommon.EnableEpochsHandlerStub{ + enableHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMultiClaimOnDelegationEnabledField: false, IsDelegationManagerFlagEnabledField: true, } diff --git a/vm/systemSmartContracts/delegation_test.go b/vm/systemSmartContracts/delegation_test.go index 86d93954064..ae56534698f 100644 --- a/vm/systemSmartContracts/delegation_test.go +++ b/vm/systemSmartContracts/delegation_test.go @@ -14,7 +14,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process/smartContract/hooks" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" @@ -45,7 +45,7 @@ func createMockArgumentsForDelegation() ArgsNewDelegation { EndOfEpochAddress: vm.EndOfEpochAddress, GovernanceSCAddress: vm.GovernanceSCAddress, AddTokensAddress: bytes.Repeat([]byte{1}, 32), - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsDelegationSmartContractFlagEnabledField: true, IsStakingV2FlagEnabledForActivationEpochCompletedField: true, IsAddTokensToDelegationFlagEnabledField: true, @@ -63,7 +63,7 @@ func addValidatorAndStakingScToVmContext(eei *vmContext) { validatorArgs.Eei = eei validatorArgs.StakingSCConfig.GenesisNodePrice = "100" validatorArgs.StakingSCAddress = vm.StakingSCAddress - enableEpochsHandler, _ := validatorArgs.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := validatorArgs.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) validatorSc, _ := NewValidatorSmartContract(validatorArgs) stakingArgs := createMockStakingScArguments() @@ -143,7 +143,7 @@ func createDelegationContractAndEEI() (*delegation, *vmContext) { InputParser: &mock.ArgumentParserMock{}, ValidatorAccountsDB: &stateMock.AccountsStub{}, ChanceComputer: &mock.RaterMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }) systemSCContainerStub := &mock.SystemSCContainerStub{GetCalled: func(key []byte) (vm.SystemSmartContract, error) { return &mock.SystemSCStub{ExecuteCalled: func(args *vmcommon.ContractCallInput) vmcommon.ReturnCode { @@ -302,7 +302,7 @@ func TestDelegationSystemSC_ExecuteDelegationDisabledShouldErr(t *testing.T) { args := createMockArgumentsForDelegation() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) d, _ := NewDelegationSystemSC(args) enableEpochsHandler.IsDelegationSmartContractFlagEnabledField = false vmInput := getDefaultVmInputForFunc("addNodes", [][]byte{}) @@ -1078,7 +1078,7 @@ func TestDelegationSystemSC_ExecuteUnStakeNodesAtEndOfEpoch(t *testing.T) { validatorArgs := createMockArgumentsForValidatorSC() validatorArgs.Eei = eei validatorArgs.StakingSCConfig.GenesisNodePrice = "100" - enableEpochsHandler, _ := validatorArgs.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := validatorArgs.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true validatorArgs.StakingSCAddress = vm.StakingSCAddress validatorSc, _ := NewValidatorSmartContract(validatorArgs) @@ -2599,7 +2599,7 @@ func prepareReDelegateRewardsComponents( args.Eei = eei args.DelegationSCConfig.MaxServiceFee = 10000 args.DelegationSCConfig.MinServiceFee = 0 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsReDelegateBelowMinCheckFlagEnabledField = extraCheckEpoch == 0 d, _ := NewDelegationSystemSC(args) vmInput := getDefaultVmInputForFunc(core.SCDeployInitFunctionName, [][]byte{big.NewInt(0).Bytes(), big.NewInt(0).Bytes()}) @@ -3899,7 +3899,7 @@ func TestDelegation_checkArgumentsForValidatorToDelegation(t *testing.T) { return vmcommon.Ok }}, nil }}) - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) args.Eei = eei args.DelegationSCConfig.MaxServiceFee = 10000 @@ -4036,7 +4036,7 @@ func TestDelegation_initFromValidatorData(t *testing.T) { return vmcommon.Ok }}, nil }} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) _ = eei.SetSystemSCContainer(systemSCContainerStub) @@ -4165,7 +4165,7 @@ func TestDelegation_mergeValidatorDataToDelegation(t *testing.T) { return vmcommon.Ok }}, nil }} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) _ = eei.SetSystemSCContainer(systemSCContainerStub) @@ -4304,7 +4304,7 @@ func TestDelegation_whitelistForMerge(t *testing.T) { return vmcommon.Ok }}, nil }} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) _ = eei.SetSystemSCContainer(systemSCContainerStub) @@ -4383,7 +4383,7 @@ func TestDelegation_deleteWhitelistForMerge(t *testing.T) { return vmcommon.Ok }}, nil }} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) _ = eei.SetSystemSCContainer(systemSCContainerStub) @@ -4441,7 +4441,7 @@ func TestDelegation_GetWhitelistForMerge(t *testing.T) { return vmcommon.Ok }}, nil }} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) _ = eei.SetSystemSCContainer(systemSCContainerStub) @@ -4567,7 +4567,7 @@ func TestDelegation_AddTokens(t *testing.T) { args := createMockArgumentsForDelegation() eei := createDefaultEei() eei.inputParser = &mock.ArgumentParserMock{} - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) args.Eei = eei d, _ := NewDelegationSystemSC(args) @@ -4595,7 +4595,7 @@ func TestDelegation_correctNodesStatus(t *testing.T) { d, eei := createDelegationContractAndEEI() vmInput := getDefaultVmInputForFunc("correctNodesStatus", nil) - enableEpochsHandler, _ := d.enableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := d.enableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsAddTokensToDelegationFlagEnabledField = false returnCode := d.Execute(vmInput) assert.Equal(t, vmcommon.UserError, returnCode) @@ -4728,7 +4728,7 @@ func createDefaultEeiArgs() VMContextArgs { InputParser: parsers.NewCallArgsParser(), ValidatorAccountsDB: &stateMock.AccountsStub{}, ChanceComputer: &mock.RaterMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsMultiClaimOnDelegationEnabledField: true, }, } @@ -4756,13 +4756,13 @@ func TestDelegationSystemSC_ExecuteChangeOwnerUserErrors(t *testing.T) { args.Eei = eei d, _ := NewDelegationSystemSC(args) - args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = false + args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = false vmInput := getDefaultVmInputForFunc("changeOwner", vmInputArgs) output := d.Execute(vmInput) assert.Equal(t, vmcommon.UserError, output) assert.True(t, strings.Contains(eei.returnMessage, vmInput.Function+" is an unknown function")) - args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = true + args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = true vmInput.CallValue = big.NewInt(0) vmInput.CallerAddr = []byte("aaa") output = d.Execute(vmInput) @@ -4814,7 +4814,7 @@ func TestDelegationSystemSC_ExecuteChangeOwner(t *testing.T) { ChanceComputer: &mock.RaterMock{}, EnableEpochsHandler: args.EnableEpochsHandler, } - args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = true + args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub).IsChangeDelegationOwnerFlagEnabledField = true eei, err := NewVMContext(argsVmContext) require.Nil(t, err) diff --git a/vm/systemSmartContracts/eei_test.go b/vm/systemSmartContracts/eei_test.go index cbdb3e8de82..c45b9dc16c0 100644 --- a/vm/systemSmartContracts/eei_test.go +++ b/vm/systemSmartContracts/eei_test.go @@ -9,6 +9,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" @@ -97,7 +100,12 @@ func TestVmContext_GetBalance(t *testing.T) { addr := []byte("addr") balance := big.NewInt(10) - account, _ := state.NewUserAccount([]byte("123")) + argsAccCreation := state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } + account, _ := state.NewUserAccount([]byte("123"), argsAccCreation) _ = account.AddToBalance(balance) blockChainHook := &mock.BlockChainHookStub{GetUserAccountCalled: func(address []byte) (a vmcommon.UserAccountHandler, e error) { @@ -196,7 +204,7 @@ func TestVmContext_IsValidatorInvalidAccountTypeShouldRetFalse(t *testing.T) { args := createDefaultEeiArgs() args.ValidatorAccountsDB = &stateMock.AccountsStub{ GetExistingAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - return state.NewEmptyUserAccount(), nil + return &stateMock.AccountWrapMock{}, nil }, } vmCtx, _ := NewVMContext(args) diff --git a/vm/systemSmartContracts/esdt_test.go b/vm/systemSmartContracts/esdt_test.go index e21d1c6302b..78440b1251b 100644 --- a/vm/systemSmartContracts/esdt_test.go +++ b/vm/systemSmartContracts/esdt_test.go @@ -14,6 +14,7 @@ import ( vmData "github.com/multiversx/mx-chain-core-go/data/vm" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" @@ -34,7 +35,7 @@ func createMockArgumentsForESDT() ArgsNewESDTSmartContract { Hasher: &hashingMocks.HasherMock{}, AddressPubKeyConverter: testscommon.NewPubkeyConverterMock(32), EndOfEpochSCAddress: vm.EndOfEpochAddress, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsESDTFlagEnabledField: true, IsGlobalMintBurnFlagEnabledField: true, IsMetaESDTSetFlagEnabledField: true, @@ -188,7 +189,7 @@ func TestEsdt_ExecuteIssueWithMultiNFTCreate(t *testing.T) { args := createMockArgumentsForESDT() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) e, _ := NewESDTSmartContract(args) vmInput := &vmcommon.ContractCallInput{ @@ -266,7 +267,7 @@ func TestEsdt_ExecuteIssueWithZero(t *testing.T) { args := createMockArgumentsForESDT() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) e, _ := NewESDTSmartContract(args) vmInput := &vmcommon.ContractCallInput{ @@ -480,7 +481,7 @@ func TestEsdt_ExecuteBurnAndMintDisabled(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsGlobalMintBurnFlagEnabledField = false eei := createDefaultEei() args.Eei = eei @@ -882,7 +883,7 @@ func TestEsdt_ExecuteIssueDisabled(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsESDTFlagEnabledField = false e, _ := NewESDTSmartContract(args) @@ -2937,7 +2938,7 @@ func TestEsdt_SetSpecialRoleTransferNotEnabledShouldErr(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsESDTTransferRoleFlagEnabledField = false token := &ESDTDataV2{ @@ -3029,7 +3030,7 @@ func TestEsdt_SetSpecialRoleTransferWithTransferRoleEnhancement(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsESDTTransferRoleFlagEnabledField = false token := &ESDTDataV2{ @@ -3122,7 +3123,7 @@ func TestEsdt_SendAllTransferRoleAddresses(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsESDTMetadataContinuousCleanupFlagEnabledField = false token := &ESDTDataV2{ @@ -3949,7 +3950,7 @@ func TestEsdt_ExecuteIssueMetaESDT(t *testing.T) { args := createMockArgumentsForESDT() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) e, _ := NewESDTSmartContract(args) enableEpochsHandler.IsMetaESDTSetFlagEnabledField = false @@ -3998,7 +3999,7 @@ func TestEsdt_ExecuteChangeSFTToMetaESDT(t *testing.T) { args := createMockArgumentsForESDT() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) e, _ := NewESDTSmartContract(args) enableEpochsHandler.IsMetaESDTSetFlagEnabledField = false @@ -4085,7 +4086,7 @@ func TestEsdt_ExecuteRegisterAndSetErrors(t *testing.T) { args := createMockArgumentsForESDT() eei := createDefaultEei() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) e, _ := NewESDTSmartContract(args) enableEpochsHandler.IsESDTRegisterAndSetAllRolesFlagEnabledField = false @@ -4212,7 +4213,7 @@ func TestEsdt_ExecuteRegisterAndSetMetaESDTShouldSetType(t *testing.T) { func registerAndSetAllRolesWithTypeCheck(t *testing.T, typeArgument []byte, expectedType []byte) { args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) eei := createDefaultEei() args.Eei = eei e, _ := NewESDTSmartContract(args) @@ -4243,7 +4244,7 @@ func TestEsdt_setBurnRoleGlobally(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) eei := createDefaultEei() args.Eei = eei @@ -4303,7 +4304,7 @@ func TestEsdt_unsetBurnRoleGlobally(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) eei := createDefaultEei() args.Eei = eei @@ -4371,7 +4372,7 @@ func TestEsdt_CheckRolesOnMetaESDT(t *testing.T) { t.Parallel() args := createMockArgumentsForESDT() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) eei := createDefaultEei() args.Eei = eei e, _ := NewESDTSmartContract(args) diff --git a/vm/systemSmartContracts/governance_test.go b/vm/systemSmartContracts/governance_test.go index 82143331b06..98134db754e 100644 --- a/vm/systemSmartContracts/governance_test.go +++ b/vm/systemSmartContracts/governance_test.go @@ -4,20 +4,20 @@ import ( "bytes" "errors" "fmt" - "github.com/multiversx/mx-chain-go/process/smartContract/hooks" - stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/stretchr/testify/assert" "math/big" "strings" "testing" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/config" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/process/smartContract/hooks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -57,7 +57,7 @@ func createArgsWithEEI(eei vm.SystemEI) ArgsNewGovernanceContract { ValidatorSCAddress: vm.ValidatorSCAddress, OwnerAddress: bytes.Repeat([]byte{1}, 32), UnBondPeriodInEpochs: 10, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsGovernanceFlagEnabledField: true, }, } @@ -70,7 +70,7 @@ func createEEIWithBlockchainHook(blockchainHook vm.BlockchainHook) vm.ContextHan InputParser: &mock.ArgumentParserMock{}, ValidatorAccountsDB: &stateMock.AccountsStub{}, ChanceComputer: &mock.RaterMock{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, }) systemSCContainerStub := &mock.SystemSCContainerStub{GetCalled: func(key []byte) (vm.SystemSmartContract, error) { return &mock.SystemSCStub{ExecuteCalled: func(args *vmcommon.ContractCallInput) vmcommon.ReturnCode { @@ -298,7 +298,7 @@ func TestGovernanceContract_ExecuteInitV2(t *testing.T) { t.Parallel() args := createMockGovernanceArgs() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) gsc, _ := NewGovernanceContract(args) callInput := createVMInput(big.NewInt(0), "initV2", vm.GovernanceSCAddress, []byte("addr2"), nil) diff --git a/vm/systemSmartContracts/staking_test.go b/vm/systemSmartContracts/staking_test.go index a84cd7e7e70..539315a30d3 100644 --- a/vm/systemSmartContracts/staking_test.go +++ b/vm/systemSmartContracts/staking_test.go @@ -18,7 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" @@ -52,7 +52,7 @@ func createMockStakingScArgumentsWithSystemScAddresses( ActivateBLSPubKeyMessageVerification: false, MinUnstakeTokensValue: "1", }, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsStakeFlagEnabledField: true, IsCorrectLastUnJailedFlagEnabledField: true, IsCorrectFirstQueuedFlagEnabledField: true, @@ -1002,7 +1002,7 @@ func TestStakingSc_StakeWithV1ShouldWork(t *testing.T) { stakingAccessAddress := []byte("stakingAccessAddress") args := createMockStakingScArguments() args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakeFlagEnabledField = false args.StakingAccessAddr = stakingAccessAddress args.Eei = eei @@ -1107,7 +1107,7 @@ func TestStakingSc_ExecuteStakeStakeJailAndSwitch(t *testing.T) { args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = 2 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true enableEpochsHandler.IsCorrectJailedNotUnStakedEmptyQueueFlagEnabledField = false args.Eei = eei @@ -1249,7 +1249,7 @@ func TestStakingSc_ExecuteStakeStakeJailAndSwitchWithBoundaries(t *testing.T) { eei := createDefaultEei() eei.blockChainHook = blockChainHook args := createStakingSCArgs(eei, stakingAccessAddress, stakeValue, maxStakedNodesNumber) - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectJailedNotUnStakedEmptyQueueFlagEnabledField = tt.flagJailedRemoveEnabled stakingSmartContract, _ := NewStakingSmartContract(args) @@ -1307,7 +1307,7 @@ func createStakingSCArgs(eei *vmContext, stakingAccessAddress []byte, stakeValue args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = uint64(maxStakedNodesNumber) - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei return args @@ -1331,7 +1331,7 @@ func TestStakingSc_ExecuteStakeStakeStakeJailJailUnJailTwice(t *testing.T) { args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = 2 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei stakingSmartContract, _ := NewStakingSmartContract(args) @@ -1459,7 +1459,7 @@ func TestStakingSc_ExecuteStakeUnStakeJailCombinations(t *testing.T) { args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = 2 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei stakingSmartContract, _ := NewStakingSmartContract(args) @@ -1769,7 +1769,7 @@ func TestStakingSc_updateConfigMaxNodesOK(t *testing.T) { stakingAccessAddress := []byte("stakingAccessAddress") args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) @@ -1841,7 +1841,7 @@ func TestStakingSC_SetOwnersOnAddressesWrongCallerShouldErr(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -1865,7 +1865,7 @@ func TestStakingSC_SetOwnersOnAddressesWrongArgumentsShouldErr(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -1890,7 +1890,7 @@ func TestStakingSC_SetOwnersOnAddressesShouldWork(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -1929,7 +1929,7 @@ func TestStakingSC_SetOwnersOnAddressesEmptyArgsShouldWork(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -1974,7 +1974,7 @@ func TestStakingSC_GetOwnerWrongCallerShouldErr(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -1998,7 +1998,7 @@ func TestStakingSC_GetOwnerWrongArgumentsShouldErr(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -2022,7 +2022,7 @@ func TestStakingSC_GetOwnerShouldWork(t *testing.T) { t.Parallel() args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true blockChainHook := &mock.BlockChainHookStub{} blockChainHook.GetStorageDataCalled = func(accountsAddress []byte, index []byte) ([]byte, uint32, error) { @@ -2072,7 +2072,7 @@ func TestStakingSc_StakeFromQueue(t *testing.T) { args := createMockStakingScArguments() args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MaxNumberOfNodesForStake = 1 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei args.StakingSCConfig.UnBondPeriod = 100 @@ -2220,7 +2220,7 @@ func TestStakingSC_ResetWaitingListUnJailed(t *testing.T) { args.StakingAccessAddr = stakingAccessAddress args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = 1 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei stakingSmartContract, _ := NewStakingSmartContract(args) @@ -2282,7 +2282,7 @@ func TestStakingSc_UnStakeNodeWhenMaxNumIsMoreShouldNotStakeFromWaiting(t *testi args.StakingSCConfig.MinStakeValue = stakeValue.Text(10) args.StakingSCConfig.MaxNumberOfNodesForStake = 2 args.MinNumNodes = 1 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei stakingSmartContract, _ := NewStakingSmartContract(args) @@ -2315,7 +2315,7 @@ func TestStakingSc_ChangeRewardAndOwnerAddress(t *testing.T) { stakingAccessAddress := []byte("stakingAccessAddress") args := createMockStakingScArguments() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) args.StakingAccessAddr = stakingAccessAddress args.Eei = eei sc, _ := NewStakingSmartContract(args) @@ -2429,7 +2429,7 @@ func TestStakingSc_RemoveFromWaitingListFirst(t *testing.T) { args := createMockStakingScArguments() args.Marshalizer = marshalizer args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectFirstQueuedFlagEnabledField = tt.flag sc, _ := NewStakingSmartContract(args) err := sc.removeFromWaitingList(firstBLS) @@ -2479,7 +2479,7 @@ func TestStakingSc_RemoveFromWaitingListSecondThatLooksLikeFirstBeforeFix(t *tes args := createMockStakingScArguments() args.Marshalizer = marshalizer args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectFirstQueuedFlagEnabledField = false sc, _ := NewStakingSmartContract(args) @@ -2628,7 +2628,7 @@ func TestStakingSc_InsertAfterLastJailedBeforeFix(t *testing.T) { args := createMockStakingScArguments() args.Marshalizer = marshalizer args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectFirstQueuedFlagEnabledField = false sc, _ := NewStakingSmartContract(args) err := sc.insertAfterLastJailed(waitingListHead, jailedBLS) @@ -2798,7 +2798,7 @@ func TestStakingSc_fixWaitingListQueueSize(t *testing.T) { } sc, eei, marshalizer, _ := makeWrongConfigForWaitingBlsKeysList(t, waitingBlsKeys) alterWaitingListLength(t, eei, marshalizer) - enableEpochsHandler, _ := sc.enableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := sc.enableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectFirstQueuedFlagEnabledField = false eei.SetGasProvided(500000000) @@ -3247,7 +3247,7 @@ func TestStakingSc_fixMissingNodeOnQueue(t *testing.T) { arguments.Arguments = make([][]byte, 0) eei.returnMessage = "" - enableEpochsHandler, _ := sc.enableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := sc.enableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsCorrectFirstQueuedFlagEnabledField = false retCode := sc.Execute(arguments) assert.Equal(t, vmcommon.UserError, retCode) diff --git a/vm/systemSmartContracts/validator_test.go b/vm/systemSmartContracts/validator_test.go index c66873e8596..471bd79606a 100644 --- a/vm/systemSmartContracts/validator_test.go +++ b/vm/systemSmartContracts/validator_test.go @@ -18,7 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/mock" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -58,7 +58,7 @@ func createMockArgumentsForValidatorSCWithSystemScAddresses( DelegationMgrSCAddress: vm.DelegationManagerSCAddress, GovernanceSCAddress: vm.GovernanceSCAddress, ShardCoordinator: &mock.ShardCoordinatorStub{}, - EnableEpochsHandler: &testscommon.EnableEpochsHandlerStub{ + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsStakeFlagEnabledField: true, IsUnBondTokensV2FlagEnabledField: true, IsValidatorToDelegationFlagEnabledField: true, @@ -425,7 +425,7 @@ func TestStakingValidatorSC_ExecuteStakeDoubleKeyAndCleanup(t *testing.T) { args.Eei = eei args.StakingSCConfig = argsStaking.StakingSCConfig - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsDoubleKeyProtectionFlagEnabledField = false validatorSc, _ := NewValidatorSmartContract(args) @@ -662,7 +662,7 @@ func TestStakingValidatorSC_ExecuteStakeStakeTokensUnBondRestakeUnStake(t *testi blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true atArgParser := parsers.NewCallArgsParser() @@ -674,7 +674,7 @@ func TestStakingValidatorSC_ExecuteStakeStakeTokensUnBondRestakeUnStake(t *testi argsStaking.StakingSCConfig.GenesisNodePrice = "10000000" argsStaking.Eei = eei argsStaking.StakingSCConfig.UnBondPeriod = 1 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true argsStaking.MinNumNodes = 0 stakingSc, _ := NewStakingSmartContract(argsStaking) @@ -933,7 +933,7 @@ func TestStakingValidatorSC_ExecuteStakeUnStake1Stake1More(t *testing.T) { argsStaking.MinNumNodes = 0 argsStaking.Eei = eei argsStaking.StakingSCConfig.UnBondPeriod = 100000 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true stakingSc, _ := NewStakingSmartContract(argsStaking) @@ -949,7 +949,7 @@ func TestStakingValidatorSC_ExecuteStakeUnStake1Stake1More(t *testing.T) { staker := []byte("staker") args.Eei = eei args.StakingSCConfig = argsStaking.StakingSCConfig - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true sc, _ := NewValidatorSmartContract(args) arguments := CreateVmContractCallInput() @@ -1209,7 +1209,7 @@ func TestStakingValidatorSC_StakeUnStake3XRestake2(t *testing.T) { blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() args.StakingSCConfig.MaxNumberOfNodesForStake = 1 - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true atArgParser := parsers.NewCallArgsParser() eei := createDefaultEei() @@ -1220,7 +1220,7 @@ func TestStakingValidatorSC_StakeUnStake3XRestake2(t *testing.T) { argsStaking.StakingSCConfig.GenesisNodePrice = "10000000" argsStaking.Eei = eei argsStaking.StakingSCConfig.UnBondPeriod = 100000 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true stakingSc, _ := NewStakingSmartContract(argsStaking) eei.SetSCAddress([]byte("addr")) @@ -1285,7 +1285,7 @@ func TestStakingValidatorSC_StakeShouldSetOwnerIfStakingV2IsEnabled(t *testing.T blsKey := []byte("blsKey") args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.MaxNumberOfNodesForStake = 1 atArgParser := parsers.NewCallArgsParser() @@ -1297,7 +1297,7 @@ func TestStakingValidatorSC_StakeShouldSetOwnerIfStakingV2IsEnabled(t *testing.T argsStaking.Eei = eei eei.SetSCAddress(args.ValidatorSCAddress) argsStaking.StakingSCConfig.UnBondPeriod = 100000 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true stakingSc, _ := NewStakingSmartContract(argsStaking) _ = eei.SetSystemSCContainer(&mock.SystemSCContainerStub{GetCalled: func(key []byte) (contract vm.SystemSmartContract, err error) { @@ -2407,7 +2407,7 @@ func TestValidatorStakingSC_ExecuteStakeUnStakeReturnsErrAsNotEnabled(t *testing }} } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakeFlagEnabledField = false args.Eei = eei @@ -2501,7 +2501,7 @@ func TestValidatorSC_ExecuteUnBondBeforePeriodEndsForV2(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriod = 1000 eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) @@ -2668,7 +2668,7 @@ func TestValidatorStakingSC_ExecuteUnStakeAndUnBondStake(t *testing.T) { args.Eei = eei args.StakingSCConfig.UnBondPeriod = unBondPeriod args.StakingSCConfig.GenesisNodePrice = valueStakedByTheCaller.Text(10) - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true argsStaking := createMockStakingScArguments() @@ -3152,7 +3152,7 @@ func TestValidatorStakingSC_ChangeRewardAddress(t *testing.T) { nodesToRunBytes := big.NewInt(1).Bytes() blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsValidatorToDelegationFlagEnabledField = false eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3245,7 +3245,7 @@ func TestStakingValidatorSC_UnstakeTokensInvalidArgumentsShouldError(t *testing. unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3277,7 +3277,7 @@ func TestStakingValidatorSC_UnstakeTokensWithCallValueShouldError(t *testing.T) unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3303,7 +3303,7 @@ func TestStakingValidatorSC_UnstakeTokensOverMaxShouldUnStake(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3347,7 +3347,7 @@ func TestStakingValidatorSC_UnstakeTokensUnderMinimumAllowedShouldErr(t *testing }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.MinUnstakeTokensValue = "2" eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) @@ -3389,7 +3389,7 @@ func TestStakingValidatorSC_UnstakeAllTokensWithActiveNodesShouldError(t *testin }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.MinDeposit = "1000" eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) @@ -3431,7 +3431,7 @@ func TestStakingValidatorSC_UnstakeTokensShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3498,7 +3498,7 @@ func TestStakingValidatorSC_UnstakeTokensHavingUnstakedShouldWork(t *testing.T) }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3570,7 +3570,7 @@ func TestStakingValidatorSC_UnstakeAllTokensShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) args.Eei = eei @@ -3650,7 +3650,7 @@ func TestStakingValidatorSC_UnbondTokensOneArgument(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) @@ -3730,7 +3730,7 @@ func TestStakingValidatorSC_UnbondTokensWithCallValueShouldError(t *testing.T) { unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -3757,7 +3757,7 @@ func TestStakingValidatorSC_UnBondTokensV1ShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true enableEpochsHandler.IsUnBondTokensV2FlagEnabledField = false args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod @@ -3839,7 +3839,7 @@ func TestStakingValidatorSC_UnBondTokensV2ShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) @@ -3920,7 +3920,7 @@ func TestStakingValidatorSC_UnBondTokensV2WithTooMuchToUnbondShouldWork(t *testi }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) @@ -4002,7 +4002,7 @@ func TestStakingValidatorSC_UnBondTokensV2WithSplitShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) @@ -4092,7 +4092,7 @@ func TestStakingValidatorSC_UnBondAllTokensWithMinDepositShouldError(t *testing. }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.MinDeposit = "1000" args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod @@ -4141,7 +4141,7 @@ func TestStakingValidatorSC_UnBondAllTokensShouldWork(t *testing.T) { }, } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.StakingSCConfig.UnBondPeriodInEpochs = unbondPeriod eei := createVmContextWithStakingSc(minStakeValue, uint64(unbondPeriod), blockChainHook) @@ -4243,7 +4243,7 @@ func TestStakingValidatorSC_GetTopUpTotalStakedWithValueShouldError(t *testing.T unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -4262,7 +4262,7 @@ func TestStakingValidatorSC_GetTopUpTotalStakedInsufficientGasShouldError(t *tes unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -4282,7 +4282,7 @@ func TestStakingValidatorSC_GetTopUpTotalStakedCallerDoesNotExistShouldError(t * unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -4301,7 +4301,7 @@ func TestStakingValidatorSC_GetTopUpTotalStakedShouldWork(t *testing.T) { unbondPeriod := uint64(10) blockChainHook := &mock.BlockChainHookStub{} args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unbondPeriod, blockChainHook) args.Eei = eei @@ -4385,7 +4385,7 @@ func TestStakingValidatorSC_UnStakeUnBondFromWaitingList(t *testing.T) { argsStaking.StakingSCConfig.GenesisNodePrice = "10000000" argsStaking.Eei = eei argsStaking.StakingSCConfig.UnBondPeriod = 100000 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true argsStaking.StakingSCConfig.MaxNumberOfNodesForStake = 1 stakingSc, _ := NewStakingSmartContract(argsStaking) @@ -4397,7 +4397,7 @@ func TestStakingValidatorSC_UnStakeUnBondFromWaitingList(t *testing.T) { args := createMockArgumentsForValidatorSC() args.StakingSCConfig = argsStaking.StakingSCConfig args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true sc, _ := NewValidatorSmartContract(args) @@ -4464,7 +4464,7 @@ func TestStakingValidatorSC_StakeUnStakeUnBondTokensNoNodes(t *testing.T) { argsStaking.StakingSCConfig.GenesisNodePrice = "10000000" argsStaking.Eei = eei argsStaking.StakingSCConfig.UnBondPeriod = 100000 - stubStaking, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + stubStaking, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) stubStaking.IsStakingV2FlagEnabledField = true argsStaking.StakingSCConfig.MaxNumberOfNodesForStake = 1 stakingSc, _ := NewStakingSmartContract(argsStaking) @@ -4475,7 +4475,7 @@ func TestStakingValidatorSC_StakeUnStakeUnBondTokensNoNodes(t *testing.T) { args := createMockArgumentsForValidatorSC() args.StakingSCConfig = argsStaking.StakingSCConfig - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true args.Eei = eei @@ -4523,7 +4523,7 @@ func TestValidatorStakingSC_UnStakeUnBondPaused(t *testing.T) { } args := createMockArgumentsForValidatorSC() - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true eei := createVmContextWithStakingSc(minStakeValue, unboundPeriod, blockChainHook) args.Eei = eei @@ -4594,7 +4594,7 @@ func TestValidatorSC_getUnStakedTokensList_InvalidArgumentsCountShouldErr(t *tes }, } args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true stakingValidatorSc, _ := NewValidatorSmartContract(args) @@ -4623,7 +4623,7 @@ func TestValidatorSC_getUnStakedTokensList_CallValueNotZeroShouldErr(t *testing. }, } args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true stakingValidatorSc, _ := NewValidatorSmartContract(args) @@ -4686,7 +4686,7 @@ func TestValidatorSC_getUnStakedTokensList(t *testing.T) { args := createMockArgumentsForValidatorSC() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true stakingValidatorSc, _ := NewValidatorSmartContract(args) @@ -4719,7 +4719,7 @@ func TestValidatorSC_getMinUnStakeTokensValueDelegationManagerNotActive(t *testi eei := &mock.SystemEIStub{} args := createMockArgumentsForValidatorSC() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsDelegationManagerFlagEnabledField = false args.StakingSCConfig.MinUnstakeTokensValue = fmt.Sprintf("%d", minUnstakeTokens) @@ -4746,7 +4746,7 @@ func TestValidatorSC_getMinUnStakeTokensValueFromDelegationManager(t *testing.T) } args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsDelegationManagerFlagEnabledField = true args.StakingSCConfig.MinUnstakeTokensValue = fmt.Sprintf("%d", minUnstakeTokens) @@ -4766,7 +4766,7 @@ func TestStakingValidatorSC_checkInputArgsForValidatorToDelegationErrors(t *test eei.inputParser = atArgParser args := createMockArgumentsForValidatorSC() args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) sc, _ := NewValidatorSmartContract(args) @@ -4905,7 +4905,7 @@ func TestStakingValidatorSC_ChangeOwnerOfValidatorData(t *testing.T) { argsStaking := createMockStakingScArguments() argsStaking.Eei = eei - enableEpochsHandler, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true stakingSc, _ := NewStakingSmartContract(argsStaking) eei.SetSCAddress([]byte("addr")) @@ -5005,7 +5005,7 @@ func TestStakingValidatorSC_MergeValidatorData(t *testing.T) { argsStaking := createMockStakingScArguments() argsStaking.Eei = eei - enableEpochsHandler, _ := argsStaking.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := argsStaking.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsStakingV2FlagEnabledField = true stakingSc, _ := NewStakingSmartContract(argsStaking) eei.SetSCAddress([]byte("addr")) @@ -5114,7 +5114,7 @@ func TestValidatorSC_getMinUnStakeTokensValueFromDelegationManagerMarshalizerFai } args.Eei = eei - enableEpochsHandler, _ := args.EnableEpochsHandler.(*testscommon.EnableEpochsHandlerStub) + enableEpochsHandler, _ := args.EnableEpochsHandler.(*enableEpochsHandlerMock.EnableEpochsHandlerStub) enableEpochsHandler.IsDelegationManagerFlagEnabledField = true args.StakingSCConfig.MinUnstakeTokensValue = fmt.Sprintf("%d", minUnstakeTokens)