diff --git a/common/channels.go b/common/channels.go index 177ac89f5c5..ba240d76b7b 100644 --- a/common/channels.go +++ b/common/channels.go @@ -1,5 +1,7 @@ package common +import "github.com/multiversx/mx-chain-core-go/core" + // GetClosedUnbufferedChannel returns an instance of a 'chan struct{}' that is already closed func GetClosedUnbufferedChannel() chan struct{} { ch := make(chan struct{}) @@ -7,3 +9,10 @@ func GetClosedUnbufferedChannel() chan struct{} { return ch } + +// CloseKeyValueHolderChan will close the channel if not nil +func CloseKeyValueHolderChan(ch chan core.KeyValueHolder) { + if ch != nil { + close(ch) + } +} diff --git a/common/constants.go b/common/constants.go index 31e75da61ee..521ef905d8e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -807,6 +807,10 @@ const ( // TrieLeavesChannelDefaultCapacity represents the default value to be used as capacity for getting all trie leaves on // a channel TrieLeavesChannelDefaultCapacity = 100 + + // TrieLeavesChannelSyncCapacity represents the value to be used as capacity for getting main trie + // leaf nodes for trie sync + TrieLeavesChannelSyncCapacity = 1000 ) // ApiOutputFormat represents the format type returned by api diff --git a/state/syncer/baseAccountsSyncer.go b/state/syncer/baseAccountsSyncer.go index 7bd821eedc1..a01f1155fed 100644 --- a/state/syncer/baseAccountsSyncer.go +++ b/state/syncer/baseAccountsSyncer.go @@ -27,7 +27,6 @@ type baseAccountsSyncer struct { timeoutHandler trie.TimeoutHandler shardId uint32 cacher storage.Cacher - rootHash []byte maxTrieLevelInMemory uint name string maxHardCapForMissingNodes int @@ -96,15 +95,11 @@ func (b *baseAccountsSyncer) syncMainTrie( rootHash []byte, trieTopic string, ctx context.Context, -) (common.Trie, error) { - b.rootHash = rootHash + leavesChan chan core.KeyValueHolder, +) error { atomic.AddInt32(&b.numMaxTries, 1) log.Trace("syncing main trie", "roothash", rootHash) - dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.enableEpochsHandler, b.maxTrieLevelInMemory) - if err != nil { - return nil, err - } b.dataTries[string(rootHash)] = struct{}{} arg := trie.ArgTrieSyncer{ @@ -119,22 +114,23 @@ func (b *baseAccountsSyncer) syncMainTrie( TimeoutHandler: b.timeoutHandler, MaxHardCapForMissingNodes: b.maxHardCapForMissingNodes, CheckNodesOnDisk: b.checkNodesOnDisk, + LeavesChan: leavesChan, } trieSyncer, err := trie.CreateTrieSyncer(arg, b.trieSyncerVersion) if err != nil { - return nil, err + return err } err = trieSyncer.StartSyncing(rootHash, ctx) if err != nil { - return nil, err + return err } atomic.AddInt32(&b.numTriesSynced, 1) log.Trace("finished syncing main trie", "roothash", rootHash) - return dataTrie.Recreate(rootHash) + return nil } func (b *baseAccountsSyncer) printStatisticsAndUpdateMetrics(ctx context.Context) { diff --git a/state/syncer/baseAccoutnsSyncer_test.go b/state/syncer/baseAccoutnsSyncer_test.go new file mode 100644 index 00000000000..da3819b05ce --- /dev/null +++ b/state/syncer/baseAccoutnsSyncer_test.go @@ -0,0 +1,116 @@ +package syncer_test + +import ( + "testing" + "time" + + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/syncer" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" + 
"github.com/multiversx/mx-chain-go/testscommon/storageManager" + "github.com/stretchr/testify/require" +) + +func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { + return syncer.ArgsNewBaseAccountsSyncer{ + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: marshallerMock.MarshalizerMock{}, + TrieStorageManager: &storageManager.StorageManagerStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + Timeout: time.Second, + Cacher: testscommon.NewCacherMock(), + UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, + AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + MaxTrieLevelInMemory: 5, + MaxHardCapForMissingNodes: 100, + TrieSyncerVersion: 3, + CheckNodesOnDisk: false, + } +} + +func TestBaseAccountsSyncer_CheckArgs(t *testing.T) { + t.Parallel() + + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.Hasher = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilHasher, err) + }) + + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.Marshalizer = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilMarshalizer, err) + }) + + t.Run("nil trie storage manager", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.TrieStorageManager = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilStorageManager, err) + }) + + t.Run("nil requests handler", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.RequestHandler = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilRequestHandler, err) + }) + + t.Run("nil cacher", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.Cacher = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilCacher, err) + }) + + t.Run("nil user accounts sync statistics handler", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.UserAccountsSyncStatisticsHandler = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilSyncStatisticsHandler, err) + }) + + t.Run("nil app status handler", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.AppStatusHandler = nil + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrNilAppStatusHandler, err) + }) + + t.Run("invalid max hard capacity for missing nodes", func(t *testing.T) { + t.Parallel() + + args := getDefaultBaseAccSyncerArgs() + args.MaxHardCapForMissingNodes = 0 + err := syncer.CheckBaseAccountsSyncerArgs(args) + require.Equal(t, state.ErrInvalidMaxHardCapForMissingNodes, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + require.Nil(t, syncer.CheckBaseAccountsSyncerArgs(getDefaultBaseAccSyncerArgs())) + }) +} diff --git a/state/syncer/export_test.go b/state/syncer/export_test.go index 1cfbb0aa96e..cfd917aba66 100644 --- a/state/syncer/export_test.go +++ b/state/syncer/export_test.go @@ -1,5 +1,30 @@ package syncer +import ( + "context" + + "github.com/multiversx/mx-chain-go/common" +) + +// UserAccountsSyncer - +type UserAccountsSyncer = userAccountsSyncer + +// ValidatorAccountsSyncer - +type ValidatorAccountsSyncer = validatorAccountsSyncer + +// CheckBaseAccountsSyncerArgs - +func 
CheckBaseAccountsSyncerArgs(args ArgsNewBaseAccountsSyncer) error { + return checkArgs(args) +} + +// SyncAccountDataTries - +func (u *userAccountsSyncer) SyncAccountDataTries( + leavesChannels *common.TrieIteratorChannels, + ctx context.Context, +) error { + return u.syncAccountDataTries(leavesChannels, ctx) +} + // GetNumHandlers - func (mtnn *missingTrieNodesNotifier) GetNumHandlers() int { return len(mtnn.handlers) diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go index 7cb47eafb95..283e3c25b3e 100644 --- a/state/syncer/userAccountsSyncer.go +++ b/state/syncer/userAccountsSyncer.go @@ -15,9 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/trie" - "github.com/multiversx/mx-chain-go/trie/keyBuilder" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -86,7 +84,6 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer, timeoutHandler: timeoutHandler, shardId: args.ShardId, cacher: args.Cacher, - rootHash: nil, maxTrieLevelInMemory: args.MaxTrieLevelInMemory, name: fmt.Sprintf("user accounts for shard %s", core.GetShardIDString(args.ShardId)), maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, @@ -126,23 +123,40 @@ func (u *userAccountsSyncer) SyncAccounts(rootHash []byte, storageMarker common. go u.printStatisticsAndUpdateMetrics(ctx) - mainTrie, err := u.syncMainTrie(rootHash, factory.AccountTrieNodesTopic, ctx) - if err != nil { - return err + leavesChannels := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelSyncCapacity), + ErrChan: errChan.NewErrChanWrapper(), } - defer func() { - _ = mainTrie.Close() + wgSyncMainTrie := &sync.WaitGroup{} + wgSyncMainTrie.Add(1) + + go func() { + err := u.syncMainTrie(rootHash, factory.AccountTrieNodesTopic, ctx, leavesChannels.LeavesChan) + if err != nil { + leavesChannels.ErrChan.WriteInChanNonBlocking(err) + } + + common.CloseKeyValueHolderChan(leavesChannels.LeavesChan) + + wgSyncMainTrie.Done() }() - log.Debug("main trie synced, starting to sync data tries", "num data tries", len(u.dataTries)) + err := u.syncAccountDataTries(leavesChannels, ctx) + if err != nil { + return err + } + + wgSyncMainTrie.Wait() - err = u.syncAccountDataTries(mainTrie, ctx) + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } - storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager()) + storageMarker.MarkStorerAsSyncedAndActive(u.trieStorageManager) + + log.Debug("main trie and data tries synced", "main trie root hash", rootHash, "num data tries", len(u.dataTries)) return nil } @@ -185,6 +199,7 @@ func (u *userAccountsSyncer) createAndStartSyncer( TimeoutHandler: u.timeoutHandler, MaxHardCapForMissingNodes: u.maxHardCapForMissingNodes, CheckNodesOnDisk: checkNodesOnDisk, + LeavesChan: nil, // not used for data tries } trieSyncer, err := trie.CreateTrieSyncer(arg, u.trieSyncerVersion) if err != nil { @@ -221,33 +236,15 @@ func (u *userAccountsSyncer) updateDataTrieStatistics(trieSyncer trie.TrieSyncer } func (u *userAccountsSyncer) syncAccountDataTries( - mainTrie common.Trie, + leavesChannels *common.TrieIteratorChannels, ctx context.Context, ) error { - defer u.printDataTrieStatistics() - - mainRootHash, err := mainTrie.RootHash() - if err != nil { - return err + if leavesChannels == nil { + return 
trie.ErrNilTrieIteratorChannels } - leavesChannels := &common.TrieIteratorChannels{ - LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: errChan.NewErrChanWrapper(), - } - err = mainTrie.GetAllLeavesOnChannel( - leavesChannels, - context.Background(), - mainRootHash, - keyBuilder.NewDisabledKeyBuilder(), - parsers.NewMainTrieLeafParser(), - ) - if err != nil { - return err - } + defer u.printDataTrieStatistics() - var errFound error - errMutex := sync.Mutex{} wg := sync.WaitGroup{} argsAccCreation := state.ArgsAccountCreation{ Hasher: u.hasher, @@ -259,11 +256,11 @@ func (u *userAccountsSyncer) syncAccountDataTries( account, err := state.NewUserAccountFromBytes(leaf.Value(), argsAccCreation) if err != nil { - log.Trace("this must be a leaf with code", "err", err) + log.Trace("this must be a leaf with code", "leaf key", leaf.Key(), "err", err) continue } - if len(account.RootHash) == 0 { + if common.IsEmptyTrie(account.RootHash) { continue } @@ -280,11 +277,9 @@ func (u *userAccountsSyncer) syncAccountDataTries( defer u.throttler.EndProcessing() log.Trace("sync data trie", "roothash", trieRootHash) - newErr := u.syncDataTrie(trieRootHash, address, ctx) - if newErr != nil { - errMutex.Lock() - errFound = newErr - errMutex.Unlock() + err := u.syncDataTrie(trieRootHash, address, ctx) + if err != nil { + leavesChannels.ErrChan.WriteInChanNonBlocking(err) } atomic.AddInt32(&u.numTriesSynced, 1) log.Trace("finished sync data trie", "roothash", trieRootHash) @@ -294,12 +289,7 @@ func (u *userAccountsSyncer) syncAccountDataTries( wg.Wait() - err = leavesChannels.ErrChan.ReadFromChanNonBlocking() - if err != nil { - return err - } - - return errFound + return nil } func (u *userAccountsSyncer) printDataTrieStatistics() { diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go new file mode 100644 index 00000000000..f6036c110c0 --- /dev/null +++ b/state/syncer/userAccountsSyncer_test.go @@ -0,0 +1,404 @@ +package syncer_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/api/mock" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/parsers" + "github.com/multiversx/mx-chain-go/state/syncer" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/storageManager" + "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/hashesHolder" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/multiversx/mx-chain-go/trie/storageMarker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getDefaultUserAccountsSyncerArgs() syncer.ArgsNewUserAccountsSyncer { + return syncer.ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + ShardId: 1, + Throttler: &mock.ThrottlerStub{}, + AddressPubKeyConverter: &testscommon.PubkeyConverterStub{}, + } +} + +func 
getDefaultArgsAccountCreation() state.ArgsAccountCreation { + return state.ArgsAccountCreation{ + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + } +} + +func TestNewUserAccountsSyncer(t *testing.T) { + t.Parallel() + + t.Run("invalid base args (nil hasher) should fail", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Hasher = nil + + syncer, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, syncer) + assert.Equal(t, state.ErrNilHasher, err) + }) + + t.Run("nil throttler", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Throttler = nil + + syncer, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, syncer) + assert.Equal(t, data.ErrNilThrottler, err) + }) + + t.Run("nil address pubkey converter", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.AddressPubKeyConverter = nil + + s, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, s) + assert.Equal(t, syncer.ErrNilPubkeyConverter, err) + }) + + t.Run("invalid timeout, should fail", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Timeout = 0 + + s, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, s) + assert.True(t, errors.Is(err, common.ErrInvalidTimeout)) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + syncer, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, err) + assert.NotNil(t, syncer) + }) +} + +func getSerializedTrieNode( + key []byte, + marshaller marshal.Marshalizer, + hasher hashing.Hasher, +) []byte { + var serializedLeafNode []byte + tsm := &storageManager.StorageManagerStub{ + PutCalled: func(key []byte, val []byte) error { + serializedLeafNode = val + return nil + }, + } + + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + _ = tr.Update(key, []byte("value")) + _ = tr.Commit() + + return serializedLeafNode +} + +func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { + t.Parallel() + + t.Run("nil storage marker", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + s, err := syncer.NewUserAccountsSyncer(args) + assert.Nil(t, err) + assert.NotNil(t, s) + + err = s.SyncAccounts([]byte("rootHash"), nil) + assert.Equal(t, syncer.ErrNilStorageMarker, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Timeout = 5 * time.Second + + key := []byte("rootHash") + serializedLeafNode := getSerializedTrieNode(key, args.Marshalizer, args.Hasher) + itn, err := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + require.Nil(t, err) + + args.TrieStorageManager = &storageManager.StorageManagerStub{ + GetCalled: func(b []byte) ([]byte, error) { + return serializedLeafNode, nil + }, + } + + cacher := testscommon.NewCacherMock() + cacher.Put(key, itn, 0) + args.Cacher = cacher + + s, err := syncer.NewUserAccountsSyncer(args) + require.Nil(t, err) + + err = s.SyncAccounts(key, storageMarker.NewDisabledStorageMarker()) + require.Nil(t, err) + }) +} + +func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, uint) { + marshalizer := &testscommon.ProtobufMarshalizerMock{} + hasher := &testscommon.KeccakMock{} + + generalCfg := 
config.TrieStorageManagerConfig{ + PruningBufferLen: 1000, + SnapshotsBufferLen: 10, + SnapshotsGoroutineNum: 1, + } + + args := trie.NewTrieStorageManagerArgs{ + MainStorer: testscommon.NewSnapshotPruningStorerMock(), + CheckpointsStorer: testscommon.NewSnapshotPruningStorerMock(), + Marshalizer: marshalizer, + Hasher: hasher, + GeneralConfig: generalCfg, + CheckpointHashesHolder: hashesHolder.NewCheckpointHashesHolder(10000000, testscommon.HashSize), + IdleProvider: &testscommon.ProcessStatusHandlerStub{}, + Identifier: "identifier", + } + + trieStorageManager, _ := trie.NewTrieStorageManager(args) + maxTrieLevelInMemory := uint(1) + + return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory +} + +func emptyTrie() common.Trie { + tr, _ := trie.NewTrie(getDefaultTrieParameters()) + + return tr +} + +func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { + t.Parallel() + + t.Run("nil leaves chan should fail", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + s, err := syncer.NewUserAccountsSyncer(args) + require.Nil(t, err) + + err = s.SyncAccountDataTries(nil, context.TODO()) + require.Equal(t, trie.ErrNilTrieIteratorChannels, err) + }) + + t.Run("throttler cannot process and closed context should fail", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Timeout = 5 * time.Second + + key := []byte("accRootHash") + serializedLeafNode := getSerializedTrieNode(key, args.Marshalizer, args.Hasher) + itn, err := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + require.Nil(t, err) + + args.TrieStorageManager = &storageManager.StorageManagerStub{ + GetCalled: func(b []byte) ([]byte, error) { + return serializedLeafNode, nil + }, + } + args.Throttler = &mock.ThrottlerStub{ + CanProcessCalled: func() bool { + return false + }, + } + + cacher := testscommon.NewCacherMock() + cacher.Put(key, itn, 0) + args.Cacher = cacher + + s, err := syncer.NewUserAccountsSyncer(args) + require.Nil(t, err) + + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr := emptyTrie() + + account, err := state.NewUserAccount(testscommon.TestPubKeyAlice, getDefaultArgsAccountCreation()) + require.Nil(t, err) + account.SetRootHash(key) + + accountBytes, err := args.Marshalizer.Marshal(account) + require.Nil(t, err) + + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), accountBytes) + _ = tr.Commit() + + leavesChannels := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), + ErrChan: errChan.NewErrChanWrapper(), + } + + rootHash, err := tr.RootHash() + require.Nil(t, err) + + err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) + require.Nil(t, err) + + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + + err = s.SyncAccountDataTries(leavesChannels, ctx) + require.Equal(t, data.ErrTimeIsOut, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := getDefaultUserAccountsSyncerArgs() + args.Timeout = 5 * time.Second + + key := []byte("accRootHash") + serializedLeafNode := getSerializedTrieNode(key, args.Marshalizer, args.Hasher) + itn, err := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + 
require.Nil(t, err) + + args.TrieStorageManager = &storageManager.StorageManagerStub{ + GetCalled: func(b []byte) ([]byte, error) { + return serializedLeafNode, nil + }, + } + + cacher := testscommon.NewCacherMock() + cacher.Put(key, itn, 0) + args.Cacher = cacher + + s, err := syncer.NewUserAccountsSyncer(args) + require.Nil(t, err) + + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr := emptyTrie() + + account, err := state.NewUserAccount(testscommon.TestPubKeyAlice, getDefaultArgsAccountCreation()) + require.Nil(t, err) + account.SetRootHash(key) + + accountBytes, err := args.Marshalizer.Marshal(account) + require.Nil(t, err) + + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), accountBytes) + _ = tr.Commit() + + leavesChannels := &common.TrieIteratorChannels{ + LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), + ErrChan: errChan.NewErrChanWrapper(), + } + + rootHash, err := tr.RootHash() + require.Nil(t, err) + + err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder(), parsers.NewMainTrieLeafParser()) + require.Nil(t, err) + + err = s.SyncAccountDataTries(leavesChannels, context.TODO()) + require.Nil(t, err) + }) +} + +func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { + t.Parallel() + + numNodesSynced := 0 + numProcessedCalled := 0 + setNumMissingCalled := 0 + args := syncer.ArgsNewUserAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + ShardId: 0, + Throttler: &mock.ThrottlerStub{}, + AddressPubKeyConverter: &testscommon.PubkeyConverterStub{}, + } + args.TrieStorageManager = &storageManager.StorageManagerStub{ + PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { + numNodesSynced++ + return nil + }, + } + args.UserAccountsSyncStatisticsHandler = &testscommon.SizeSyncStatisticsHandlerStub{ + AddNumProcessedCalled: func(value int) { + numProcessedCalled++ + }, + SetNumMissingCalled: func(rootHash []byte, value int) { + setNumMissingCalled++ + assert.Equal(t, 0, value) + }, + } + + var serializedLeafNode []byte + tsm := &storageManager.StorageManagerStub{ + PutCalled: func(key []byte, val []byte) error { + serializedLeafNode = val + return nil + }, + } + + tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + key := []byte("key") + value := []byte("value") + _ = tr.Update(key, value) + rootHash, _ := tr.RootHash() + _ = tr.Commit() + + args.Cacher = &testscommon.CacherStub{ + GetCalled: func(key []byte) (value interface{}, ok bool) { + interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + return interceptedNode, true + }, + } + + syncer, _ := syncer.NewUserAccountsSyncer(args) + // test that timeout watchdog is reset + time.Sleep(args.Timeout * 2) + syncer.MissingDataTrieNodeFound(rootHash) + + assert.Equal(t, 1, numNodesSynced) + assert.Equal(t, 1, numProcessedCalled) + assert.Equal(t, 1, setNumMissingCalled) +} + +func TestUserAccountsSyncer_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var uas *syncer.UserAccountsSyncer + assert.True(t, uas.IsInterfaceNil()) + + uas, err := syncer.NewUserAccountsSyncer(getDefaultUserAccountsSyncerArgs()) + require.Nil(t, err) + assert.False(t, uas.IsInterfaceNil()) +} diff --git a/state/syncer/validatorAccountsSyncer.go 
b/state/syncer/validatorAccountsSyncer.go index c4893ebab26..943368441d4 100644 --- a/state/syncer/validatorAccountsSyncer.go +++ b/state/syncer/validatorAccountsSyncer.go @@ -43,7 +43,6 @@ func NewValidatorAccountsSyncer(args ArgsNewValidatorAccountsSyncer) (*validator timeoutHandler: timeoutHandler, shardId: core.MetachainShardId, cacher: args.Cacher, - rootHash: nil, maxTrieLevelInMemory: args.MaxTrieLevelInMemory, name: "peer accounts", maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, @@ -80,12 +79,17 @@ func (v *validatorAccountsSyncer) SyncAccounts(rootHash []byte, storageMarker co go v.printStatisticsAndUpdateMetrics(ctx) - mainTrie, err := v.syncMainTrie(rootHash, factory.ValidatorTrieNodesTopic, ctx) + err := v.syncMainTrie( + rootHash, + factory.ValidatorTrieNodesTopic, + ctx, + nil, // not used for validator accounts syncer + ) if err != nil { return err } - storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager()) + storageMarker.MarkStorerAsSyncedAndActive(v.trieStorageManager) return nil } diff --git a/state/syncer/validatorAccountsSyncer_test.go b/state/syncer/validatorAccountsSyncer_test.go index 4624c550b16..b4a025883f1 100644 --- a/state/syncer/validatorAccountsSyncer_test.go +++ b/state/syncer/validatorAccountsSyncer_test.go @@ -1,24 +1,120 @@ -package syncer +package syncer_test import ( + "errors" "testing" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/syncer" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/storageManager" + "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -// TODO add more tests +func TestNewValidatorAccountsSyncer(t *testing.T) { + t.Parallel() + + t.Run("invalid base args (nil hasher) should fail", func(t *testing.T) { + t.Parallel() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } + args.Hasher = nil + + syncer, err := syncer.NewValidatorAccountsSyncer(args) + assert.Nil(t, syncer) + assert.Equal(t, state.ErrNilHasher, err) + }) + + t.Run("invalid timeout, should fail", func(t *testing.T) { + t.Parallel() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } + args.Timeout = 0 + + s, err := syncer.NewValidatorAccountsSyncer(args) + assert.Nil(t, s) + assert.True(t, errors.Is(err, common.ErrInvalidTimeout)) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } + v, err := syncer.NewValidatorAccountsSyncer(args) + require.Nil(t, err) + require.NotNil(t, v) + }) +} func TestValidatorAccountsSyncer_SyncAccounts(t *testing.T) { t.Parallel() - args := ArgsNewValidatorAccountsSyncer{ - ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), - } + key := []byte("rootHash") + + t.Run("nil storage marker", func(t *testing.T) { + t.Parallel() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } - syncer, err := NewValidatorAccountsSyncer(args) - assert.Nil(t, err) - assert.NotNil(t, syncer) + v, err := syncer.NewValidatorAccountsSyncer(args) + require.Nil(t, err) + require.NotNil(t, v) - err = syncer.SyncAccounts([]byte("rootHash"), nil) - assert.Equal(t, 
ErrNilStorageMarker, err) + err = v.SyncAccounts(key, nil) + require.Equal(t, syncer.ErrNilStorageMarker, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } + + serializedLeafNode := getSerializedTrieNode(key, args.Marshalizer, args.Hasher) + itn, err := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) + require.Nil(t, err) + + args.TrieStorageManager = &storageManager.StorageManagerStub{ + GetCalled: func(b []byte) ([]byte, error) { + return serializedLeafNode, nil + }, + } + + cacher := testscommon.NewCacherMock() + cacher.Put(key, itn, 0) + args.Cacher = cacher + + v, err := syncer.NewValidatorAccountsSyncer(args) + require.Nil(t, err) + + err = v.SyncAccounts(key, storageMarker.NewDisabledStorageMarker()) + require.Nil(t, err) + }) +} + +func TestValidatorAccountsSyncer_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var vas *syncer.ValidatorAccountsSyncer + assert.True(t, vas.IsInterfaceNil()) + + args := syncer.ArgsNewValidatorAccountsSyncer{ + ArgsNewBaseAccountsSyncer: getDefaultBaseAccSyncerArgs(), + } + vas, err := syncer.NewValidatorAccountsSyncer(args) + require.Nil(t, err) + assert.False(t, vas.IsInterfaceNil()) } diff --git a/trie/depthFirstSync.go b/trie/depthFirstSync.go index f51388ed5dc..b2ef76ac35a 100644 --- a/trie/depthFirstSync.go +++ b/trie/depthFirstSync.go @@ -31,6 +31,7 @@ type depthFirstTrieSyncer struct { checkNodesOnDisk bool nodes *trieNodesHandler requestedHashes map[string]*request + leavesChan chan core.KeyValueHolder } // NewDepthFirstTrieSyncer creates a new instance of trieSyncer that uses the depth-first algorithm @@ -58,6 +59,7 @@ func NewDepthFirstTrieSyncer(arg ArgTrieSyncer) (*depthFirstTrieSyncer, error) { timeoutHandler: arg.TimeoutHandler, maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes, checkNodesOnDisk: arg.CheckNodesOnDisk, + leavesChan: arg.LeavesChan, } return d, nil @@ -251,6 +253,8 @@ func (d *depthFirstTrieSyncer) storeTrieNode(element node) error { d.trieSyncStatistics.AddNumBytesReceived(uint64(numBytes)) d.updateStats(uint64(numBytes), element) + writeLeafNodeToChan(element, d.leavesChan) + return nil } diff --git a/trie/depthFirstSync_test.go b/trie/depthFirstSync_test.go index 95ea4db62aa..456c1b1f3e8 100644 --- a/trie/depthFirstSync_test.go +++ b/trie/depthFirstSync_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "sync" "testing" "time" @@ -108,6 +109,7 @@ func TestDepthFirstTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) { arg := createMockArgument(time.Minute) arg.RequestHandler = createRequesterResolver(trSource, arg.InterceptedNodes, nil) + arg.LeavesChan = make(chan core.KeyValueHolder, 110) d, _ := NewDepthFirstTrieSyncer(arg) ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*30) @@ -134,6 +136,22 @@ func TestDepthFirstTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) { assert.True(t, d.NumTrieNodes() > d.NumLeaves()) assert.True(t, d.NumBytes() > 0) assert.True(t, d.Duration() > 0) + + wg := &sync.WaitGroup{} + wg.Add(numKeysValues) + + numLeavesOnChan := 0 + go func() { + for range arg.LeavesChan { + numLeavesOnChan++ + wg.Done() + } + }() + + wg.Wait() + + assert.Equal(t, numKeysValues, numLeavesOnChan) + log.Info("synced trie", "num trie nodes", d.NumTrieNodes(), "num leaves", d.NumLeaves(), diff --git a/trie/doubleListSync.go b/trie/doubleListSync.go index b5f68c7bf0d..edf7ed76d23 100644 --- a/trie/doubleListSync.go +++ 
b/trie/doubleListSync.go @@ -43,6 +43,7 @@ type doubleListTrieSyncer struct { existingNodes map[string]node missingHashes map[string]struct{} requestedHashes map[string]*request + leavesChan chan core.KeyValueHolder } // NewDoubleListTrieSyncer creates a new instance of trieSyncer that uses 2 list for keeping the "margin" nodes. @@ -73,6 +74,7 @@ func NewDoubleListTrieSyncer(arg ArgTrieSyncer) (*doubleListTrieSyncer, error) { timeoutHandler: arg.TimeoutHandler, maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes, checkNodesOnDisk: arg.CheckNodesOnDisk, + leavesChan: arg.LeavesChan, } return d, nil @@ -207,6 +209,8 @@ func (d *doubleListTrieSyncer) processExistingNodes() error { return err } + writeLeafNodeToChan(element, d.leavesChan) + d.timeoutHandler.ResetWatchdog() var children []node diff --git a/trie/doubleListSync_test.go b/trie/doubleListSync_test.go index c4765287c7c..65197f171fc 100644 --- a/trie/doubleListSync_test.go +++ b/trie/doubleListSync_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "sync" "testing" "time" @@ -227,6 +228,22 @@ func TestDoubleListTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) { assert.True(t, d.NumTrieNodes() > d.NumLeaves()) assert.True(t, d.NumBytes() > 0) assert.True(t, d.Duration() > 0) + + wg := &sync.WaitGroup{} + wg.Add(numKeysValues) + + numLeavesOnChan := 0 + go func() { + for range arg.LeavesChan { + numLeavesOnChan++ + wg.Done() + } + }() + + wg.Wait() + + assert.Equal(t, numKeysValues, numLeavesOnChan) + log.Info("synced trie", "num trie nodes", d.NumTrieNodes(), "num leaves", d.NumLeaves(), diff --git a/trie/sync.go b/trie/sync.go index 89c8d3ef3c5..ce48f8c8e6b 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" @@ -40,6 +41,7 @@ type trieSyncer struct { trieSyncStatistics data.SyncStatisticsHandler timeoutHandler TimeoutHandler maxHardCapForMissingNodes int + leavesChan chan core.KeyValueHolder } const maxNewMissingAddedPerTurn = 10 @@ -57,6 +59,7 @@ type ArgTrieSyncer struct { MaxHardCapForMissingNodes int CheckNodesOnDisk bool TimeoutHandler TimeoutHandler + LeavesChan chan core.KeyValueHolder } // NewTrieSyncer creates a new instance of trieSyncer @@ -85,6 +88,7 @@ func NewTrieSyncer(arg ArgTrieSyncer) (*trieSyncer, error) { trieSyncStatistics: arg.TrieSyncStatistics, timeoutHandler: arg.TimeoutHandler, maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes, + leavesChan: arg.LeavesChan, } return ts, nil @@ -244,6 +248,9 @@ func (ts *trieSyncer) checkIfSynced() (bool, error) { if err != nil { return false, err } + + writeLeafNodeToChan(currentNode, ts.leavesChan) + ts.timeoutHandler.ResetWatchdog() ts.updateStats(uint64(numBytes), currentNode) @@ -363,6 +370,20 @@ func trieNode( return decodedNode, nil } +func writeLeafNodeToChan(element node, ch chan core.KeyValueHolder) { + if ch == nil { + return + } + + leafNodeElement, isLeaf := element.(*leafNode) + if !isLeaf { + return + } + + trieLeaf := keyValStorage.NewKeyValStorage(leafNodeElement.Key, leafNodeElement.Value) + ch <- trieLeaf +} + func (ts *trieSyncer) requestNodes() uint32 { ts.mutOperation.RLock() numUnResolvedNodes := uint32(len(ts.nodesForTrie)) diff --git a/trie/sync_test.go b/trie/sync_test.go index 33a49abee76..ab5083eb85a 100644 
--- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -3,6 +3,7 @@ package trie import ( "context" "errors" + "sync" "testing" "time" @@ -40,116 +41,121 @@ func createMockArgument(timeout time.Duration) ArgTrieSyncer { TrieSyncStatistics: statistics.NewTrieSyncStatistics(), TimeoutHandler: testscommon.NewTimeoutHandlerMock(timeout), MaxHardCapForMissingNodes: 500, + LeavesChan: make(chan core.KeyValueHolder, 100), } } -func TestNewTrieSyncer_NilRequestHandlerShouldErr(t *testing.T) { +func TestNewTrieSyncer(t *testing.T) { t.Parallel() - arg := createMockArgument(time.Minute) - arg.RequestHandler = nil + t.Run("nil request handler", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.Equal(t, err, ErrNilRequestHandler) -} + arg := createMockArgument(time.Minute) + arg.RequestHandler = nil -func TestNewTrieSyncer_NilInterceptedNodesShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.Equal(t, err, ErrNilRequestHandler) + }) - arg := createMockArgument(time.Minute) - arg.InterceptedNodes = nil + t.Run("nil intercepted nodes", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.Equal(t, err, data.ErrNilCacher) -} + arg := createMockArgument(time.Minute) + arg.InterceptedNodes = nil -func TestNewTrieSyncer_EmptyTopicShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.Equal(t, err, data.ErrNilCacher) + }) - arg := createMockArgument(time.Minute) - arg.Topic = "" + t.Run("empty topic should fail", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.Equal(t, err, ErrInvalidTrieTopic) -} + arg := createMockArgument(time.Minute) + arg.Topic = "" -func TestNewTrieSyncer_NilTrieStatisticsShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.Equal(t, err, ErrInvalidTrieTopic) + }) - arg := createMockArgument(time.Minute) - arg.TrieSyncStatistics = nil + t.Run("nil trie statistics", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.Equal(t, err, ErrNilTrieSyncStatistics) -} + arg := createMockArgument(time.Minute) + arg.TrieSyncStatistics = nil -func TestNewTrieSyncer_NilDatabaseShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.Equal(t, err, ErrNilTrieSyncStatistics) + }) - arg := createMockArgument(time.Minute) - arg.DB = nil + t.Run("nil database", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.True(t, errors.Is(err, ErrNilDatabase)) -} + arg := createMockArgument(time.Minute) + arg.DB = nil -func TestNewTrieSyncer_NilMarshalizerShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.True(t, errors.Is(err, ErrNilDatabase)) + }) - arg := createMockArgument(time.Minute) - arg.Marshalizer = nil + t.Run("nil marshalizer", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.True(t, errors.Is(err, ErrNilMarshalizer)) -} + arg := createMockArgument(time.Minute) + arg.Marshalizer = nil -func TestNewTrieSyncer_NilHasherShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.True(t, 
errors.Is(err, ErrNilMarshalizer)) + }) - arg := createMockArgument(time.Minute) - arg.Hasher = nil + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.True(t, errors.Is(err, ErrNilHasher)) -} + arg := createMockArgument(time.Minute) + arg.Hasher = nil -func TestNewTrieSyncer_NilTimeoutHandlerShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.True(t, errors.Is(err, ErrNilHasher)) + }) - arg := createMockArgument(time.Minute) - arg.TimeoutHandler = nil + t.Run("nil timeout handler", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.True(t, errors.Is(err, ErrNilTimeoutHandler)) -} + arg := createMockArgument(time.Minute) + arg.TimeoutHandler = nil -func TestNewTrieSyncer_InvalidMaxHardCapForMissingNodesShouldErr(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.True(t, errors.Is(err, ErrNilTimeoutHandler)) + }) - arg := createMockArgument(time.Minute) - arg.MaxHardCapForMissingNodes = 0 + t.Run("invalid max hard capacity for missing nodes", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.True(t, check.IfNil(ts)) - assert.True(t, errors.Is(err, ErrInvalidMaxHardCapForMissingNodes)) -} + arg := createMockArgument(time.Minute) + arg.MaxHardCapForMissingNodes = 0 -func TestNewTrieSyncer_ShouldWork(t *testing.T) { - t.Parallel() + ts, err := NewTrieSyncer(arg) + assert.True(t, check.IfNil(ts)) + assert.True(t, errors.Is(err, ErrInvalidMaxHardCapForMissingNodes)) + }) - arg := createMockArgument(time.Minute) + t.Run("should work", func(t *testing.T) { + t.Parallel() - ts, err := NewTrieSyncer(arg) - assert.False(t, check.IfNil(ts)) - assert.Nil(t, err) + arg := createMockArgument(time.Minute) + + ts, err := NewTrieSyncer(arg) + assert.False(t, check.IfNil(ts)) + assert.Nil(t, err) + }) } func TestTrieSync_InterceptedNodeShouldNotBeAddedToNodesForTrieIfNodeReceived(t *testing.T) { @@ -220,6 +226,10 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { err = bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0) require.Nil(t, err) + leaves, err := bn.getChildren(db) + require.Nil(t, err) + numLeaves := len(leaves) + arg := createMockArgument(timeout) arg.RequestHandler = &testscommon.RequestHandlerStub{ RequestTrieNodesCalled: func(destShardID uint32, hashes [][]byte, topic string) { @@ -235,6 +245,21 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { err = ts.StartSyncing(rootHash, context.Background()) assert.Nil(t, err) + + wg := &sync.WaitGroup{} + wg.Add(numLeaves) + + numLeavesOnChan := 0 + go func() { + for range arg.LeavesChan { + numLeavesOnChan++ + wg.Done() + } + }() + + wg.Wait() + + assert.Equal(t, numLeaves, numLeavesOnChan) } func TestTrieSync_StartSyncing(t *testing.T) { diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go index 24f770ac57c..a8963058169 100644 --- a/trie/trieStorageManager.go +++ b/trie/trieStorageManager.go @@ -335,19 +335,19 @@ func (tsm *trieStorageManager) TakeSnapshot( ) { if iteratorChannels.ErrChan == nil { log.Error("programming error in trieStorageManager.TakeSnapshot, cannot take snapshot because errChan is nil") - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() 
return } if tsm.IsClosed() { - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() return } if bytes.Equal(rootHash, common.EmptyTrieHash) { log.Trace("should not snapshot an empty trie") - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() return } @@ -368,7 +368,7 @@ func (tsm *trieStorageManager) TakeSnapshot( case tsm.snapshotReq <- snapshotEntry: case <-tsm.closer.ChanClose(): tsm.ExitPruningBufferingMode() - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() } } @@ -385,19 +385,19 @@ func (tsm *trieStorageManager) SetCheckpoint( ) { if iteratorChannels.ErrChan == nil { log.Error("programming error in trieStorageManager.SetCheckpoint, cannot set checkpoint because errChan is nil") - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() return } if tsm.IsClosed() { - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() return } if bytes.Equal(rootHash, common.EmptyTrieHash) { log.Trace("should not set checkpoint for empty trie") - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() return } @@ -415,21 +415,15 @@ func (tsm *trieStorageManager) SetCheckpoint( case tsm.checkpointReq <- checkpointEntry: case <-tsm.closer.ChanClose(): tsm.ExitPruningBufferingMode() - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) stats.SnapshotFinished() } } -func safelyCloseChan(ch chan core.KeyValueHolder) { - if ch != nil { - close(ch) - } -} - func (tsm *trieStorageManager) finishOperation(snapshotEntry *snapshotsQueueEntry, message string) { tsm.ExitPruningBufferingMode() log.Trace(message, "rootHash", snapshotEntry.rootHash) - safelyCloseChan(snapshotEntry.iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(snapshotEntry.iteratorChannels.LeavesChan) snapshotEntry.stats.SnapshotFinished() } diff --git a/trie/trieStorageManagerWithoutCheckpoints.go b/trie/trieStorageManagerWithoutCheckpoints.go index d2f4b93e507..975a9a10111 100644 --- a/trie/trieStorageManagerWithoutCheckpoints.go +++ b/trie/trieStorageManagerWithoutCheckpoints.go @@ -30,7 +30,7 @@ func (tsm *trieStorageManagerWithoutCheckpoints) SetCheckpoint( stats common.SnapshotStatisticsHandler, ) { if iteratorChannels != nil { - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) } stats.SnapshotFinished() diff --git a/trie/trieStorageManagerWithoutSnapshot.go b/trie/trieStorageManagerWithoutSnapshot.go index 337a74f8f9d..7e538eaf184 100644 --- a/trie/trieStorageManagerWithoutSnapshot.go +++ b/trie/trieStorageManagerWithoutSnapshot.go @@ -38,7 +38,7 @@ func (tsm *trieStorageManagerWithoutSnapshot) PutInEpochWithoutCache(key []byte, // TakeSnapshot does nothing, as snapshots are disabled for this implementation func (tsm *trieStorageManagerWithoutSnapshot) TakeSnapshot(_ string, _ []byte, _ []byte, iteratorChannels *common.TrieIteratorChannels, _ chan []byte, stats common.SnapshotStatisticsHandler, _ uint32) { if iteratorChannels != nil { - safelyCloseChan(iteratorChannels.LeavesChan) + common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan) } 
stats.SnapshotFinished() }
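
The pattern this diff introduces — the main-trie syncer streaming leaf nodes on a channel while the data-trie syncer consumes them concurrently, with a nil-safe close helper — can be illustrated with a minimal, self-contained Go sketch. The keyValue type and the syncMainTrie stand-in below are hypothetical placeholders, not code from this repository; only the channel discipline mirrors the real change (common.CloseKeyValueHolderChan, the nil-channel guard in trie's writeLeafNodeToChan, and the WaitGroup handoff in userAccountsSyncer.SyncAccounts).

package main

import (
	"fmt"
	"sync"
)

// keyValue is a hypothetical stand-in for core.KeyValueHolder.
type keyValue struct {
	key   []byte
	value []byte
}

// closeKeyValueChan mirrors common.CloseKeyValueHolderChan: closing a
// nil channel would panic, so the helper is a no-op in that case.
func closeKeyValueChan(ch chan keyValue) {
	if ch != nil {
		close(ch)
	}
}

// sendLeaf mirrors the writeLeafNodeToChan guard: a nil channel means
// the caller did not request leaf streaming, so the write is skipped.
func sendLeaf(ch chan keyValue, kv keyValue) {
	if ch == nil {
		return
	}
	ch <- kv
}

// syncMainTrie is a hypothetical stand-in for the main-trie sync: it
// emits each leaf as soon as it is stored, instead of recreating the
// synced trie and iterating its leaves afterwards.
func syncMainTrie(leavesChan chan keyValue) error {
	for i := 0; i < 5; i++ {
		sendLeaf(leavesChan, keyValue{
			key:   []byte(fmt.Sprintf("account-%d", i)),
			value: []byte("serialized account"),
		})
	}
	return nil
}

func main() {
	// Buffered in the spirit of TrieLeavesChannelSyncCapacity, so the
	// producer rarely blocks on a slow consumer.
	leavesChan := make(chan keyValue, 1000)

	wgMainTrie := &sync.WaitGroup{}
	wgMainTrie.Add(1)

	var mainTrieErr error
	go func() {
		defer wgMainTrie.Done()
		// The producer owns the channel and always closes it, even on
		// error, so the consumer's range loop is guaranteed to end.
		defer closeKeyValueChan(leavesChan)
		mainTrieErr = syncMainTrie(leavesChan)
	}()

	// Consumer: handle each account as soon as its main-trie leaf
	// arrives, rather than waiting for the whole main trie to finish.
	for leaf := range leavesChan {
		fmt.Printf("syncing data trie for %s\n", leaf.key)
	}

	wgMainTrie.Wait() // happens-before edge for reading mainTrieErr
	if mainTrieErr != nil {
		fmt.Println("main trie sync failed:", mainTrieErr)
	}
}

Two design choices worth noting: the producer goroutine alone closes the channel (even on error), which is why SyncAccounts can safely range-style consume and then read the error channel non-blockingly after the WaitGroup; and a nil channel is the explicit opt-out used by the validator syncer and the data-trie syncers, which is why both helpers guard against nil instead of assuming a channel exists.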