diff --git a/common/channels.go b/common/channels.go
index 177ac89f5c5..ba240d76b7b 100644
--- a/common/channels.go
+++ b/common/channels.go
@@ -1,5 +1,7 @@
 package common
 
+import "github.com/multiversx/mx-chain-core-go/core"
+
 // GetClosedUnbufferedChannel returns an instance of a 'chan struct{}' that is already closed
 func GetClosedUnbufferedChannel() chan struct{} {
 	ch := make(chan struct{})
@@ -7,3 +9,10 @@ func GetClosedUnbufferedChannel() chan struct{} {
 
 	return ch
 }
+
+// CloseKeyValueHolderChan will close the channel if not nil
+func CloseKeyValueHolderChan(ch chan core.KeyValueHolder) {
+	if ch != nil {
+		close(ch)
+	}
+}
diff --git a/common/constants.go b/common/constants.go
index 7dc897076e9..66cbd149718 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -814,6 +814,10 @@ const (
 	// TrieLeavesChannelDefaultCapacity represents the default value to be used as capacity for getting all trie leaves on
 	// a channel
 	TrieLeavesChannelDefaultCapacity = 100
+
+	// TrieLeavesChannelSyncCapacity represents the value to be used as capacity for getting main trie
+	// leaf nodes for trie sync
+	TrieLeavesChannelSyncCapacity = 1000
 )
 
 // ApiOutputFormat represents the format type returned by api
diff --git a/state/syncer/baseAccountsSyncer.go b/state/syncer/baseAccountsSyncer.go
index 18d28fc3370..af0ef1fb456 100644
--- a/state/syncer/baseAccountsSyncer.go
+++ b/state/syncer/baseAccountsSyncer.go
@@ -27,7 +27,6 @@ type baseAccountsSyncer struct {
 	timeoutHandler            trie.TimeoutHandler
 	shardId                   uint32
 	cacher                    storage.Cacher
-	rootHash                  []byte
 	maxTrieLevelInMemory      uint
 	name                      string
 	maxHardCapForMissingNodes int
@@ -93,15 +92,11 @@ func (b *baseAccountsSyncer) syncMainTrie(
 	rootHash []byte,
 	trieTopic string,
 	ctx context.Context,
-) (common.Trie, error) {
-	b.rootHash = rootHash
+	leavesChan chan core.KeyValueHolder,
+) error {
 	atomic.AddInt32(&b.numMaxTries, 1)
 
 	log.Trace("syncing main trie", "roothash", rootHash)
-	dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.maxTrieLevelInMemory)
-	if err != nil {
-		return nil, err
-	}
 
 	b.dataTries[string(rootHash)] = struct{}{}
 	arg := trie.ArgTrieSyncer{
@@ -116,22 +111,23 @@ func (b *baseAccountsSyncer) syncMainTrie(
 		TimeoutHandler:            b.timeoutHandler,
 		MaxHardCapForMissingNodes: b.maxHardCapForMissingNodes,
 		CheckNodesOnDisk:          b.checkNodesOnDisk,
+		LeavesChan:                leavesChan,
 	}
 	trieSyncer, err := trie.CreateTrieSyncer(arg, b.trieSyncerVersion)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	err = trieSyncer.StartSyncing(rootHash, ctx)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	atomic.AddInt32(&b.numTriesSynced, 1)
 
 	log.Trace("finished syncing main trie", "roothash", rootHash)
 
-	return dataTrie.Recreate(rootHash)
+	return nil
 }
 
 func (b *baseAccountsSyncer) printStatisticsAndUpdateMetrics(ctx context.Context) {
diff --git a/state/syncer/export_test.go b/state/syncer/export_test.go
index dcea224d65e..e8fade258ac 100644
--- a/state/syncer/export_test.go
+++ b/state/syncer/export_test.go
@@ -19,8 +19,8 @@ func CheckBaseAccountsSyncerArgs(args ArgsNewBaseAccountsSyncer) error {
 
 // SyncAccountDataTries -
 func (u *userAccountsSyncer) SyncAccountDataTries(
-	mainTrie common.Trie,
+	leavesChannels *common.TrieIteratorChannels,
 	ctx context.Context,
 ) error {
-	return u.syncAccountDataTries(mainTrie, ctx)
+	return u.syncAccountDataTries(leavesChannels, ctx)
 }
diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go
index 5f992eef9d9..fec67fe8fc6 100644
--- a/state/syncer/userAccountsSyncer.go
+++ b/state/syncer/userAccountsSyncer.go
@@ -17,7 +17,6 @@ import (
 	"github.com/multiversx/mx-chain-go/process/factory"
 	"github.com/multiversx/mx-chain-go/state"
 	"github.com/multiversx/mx-chain-go/trie"
-	"github.com/multiversx/mx-chain-go/trie/keyBuilder"
 	logger "github.com/multiversx/mx-chain-logger-go"
 )
 
@@ -83,7 +82,6 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer,
 		timeoutHandler:            timeoutHandler,
 		shardId:                   args.ShardId,
 		cacher:                    args.Cacher,
-		rootHash:                  nil,
 		maxTrieLevelInMemory:      args.MaxTrieLevelInMemory,
 		name:                      fmt.Sprintf("user accounts for shard %s", core.GetShardIDString(args.ShardId)),
 		maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes,
@@ -119,23 +117,40 @@ func (u *userAccountsSyncer) SyncAccounts(rootHash []byte) error {
 
 	go u.printStatisticsAndUpdateMetrics(ctx)
 
-	mainTrie, err := u.syncMainTrie(rootHash, factory.AccountTrieNodesTopic, ctx)
-	if err != nil {
-		return err
+	leavesChannels := &common.TrieIteratorChannels{
+		LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelSyncCapacity),
+		ErrChan:    errChan.NewErrChanWrapper(),
 	}
 
-	defer func() {
-		_ = mainTrie.Close()
+	wgSyncMainTrie := &sync.WaitGroup{}
+	wgSyncMainTrie.Add(1)
+
+	go func() {
+		err := u.syncMainTrie(rootHash, factory.AccountTrieNodesTopic, ctx, leavesChannels.LeavesChan)
+		if err != nil {
+			leavesChannels.ErrChan.WriteInChanNonBlocking(err)
+		}
+
+		common.CloseKeyValueHolderChan(leavesChannels.LeavesChan)
+
+		wgSyncMainTrie.Done()
 	}()
 
-	log.Debug("main trie synced, starting to sync data tries", "num data tries", len(u.dataTries))
+	err := u.syncAccountDataTries(leavesChannels, ctx)
+	if err != nil {
+		return err
+	}
+
+	wgSyncMainTrie.Wait()
 
-	err = u.syncAccountDataTries(mainTrie, ctx)
+	err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
 	if err != nil {
 		return err
 	}
 
-	u.storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager())
+	u.storageMarker.MarkStorerAsSyncedAndActive(u.trieStorageManager)
+
+	log.Debug("main trie and data tries synced", "main trie root hash", rootHash, "num data tries", len(u.dataTries))
 
 	return nil
 }
@@ -163,6 +178,7 @@ func (u *userAccountsSyncer) syncDataTrie(rootHash []byte, address []byte, ctx c
 		TimeoutHandler:            u.timeoutHandler,
 		MaxHardCapForMissingNodes: u.maxHardCapForMissingNodes,
 		CheckNodesOnDisk:          u.checkNodesOnDisk,
+		LeavesChan:                nil, // not used for data tries
 	}
 	trieSyncer, err := trie.CreateTrieSyncer(arg, u.trieSyncerVersion)
 	if err != nil {
@@ -202,40 +218,28 @@ func (u *userAccountsSyncer) updateDataTrieStatistics(trieSyncer trie.TrieSyncer
 }
 
 func (u *userAccountsSyncer) syncAccountDataTries(
-	mainTrie common.Trie,
+	leavesChannels *common.TrieIteratorChannels,
 	ctx context.Context,
 ) error {
-	defer u.printDataTrieStatistics()
-
-	mainRootHash, err := mainTrie.RootHash()
-	if err != nil {
-		return err
+	if leavesChannels == nil {
+		return trie.ErrNilTrieIteratorChannels
 	}
 
-	leavesChannels := &common.TrieIteratorChannels{
-		LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
-		ErrChan:    errChan.NewErrChanWrapper(),
-	}
-	err = mainTrie.GetAllLeavesOnChannel(leavesChannels, context.Background(), mainRootHash, keyBuilder.NewDisabledKeyBuilder())
-	if err != nil {
-		return err
-	}
+	defer u.printDataTrieStatistics()
 
-	var errFound error
-	errMutex := sync.Mutex{}
 	wg := sync.WaitGroup{}
 
 	for leaf := range leavesChannels.LeavesChan {
 		u.resetTimeoutHandlerWatchdog()
 
 		account := state.NewEmptyUserAccount()
-		err = u.marshalizer.Unmarshal(account, leaf.Value())
+		err := u.marshalizer.Unmarshal(account, leaf.Value())
 		if err != nil {
-			log.Trace("this must be a leaf with code", "err", err)
+			log.Trace("this must be a leaf with code", "leaf key", leaf.Key(), "err", err)
 			continue
 		}
 
-		if len(account.RootHash) == 0 {
+		if common.IsEmptyTrie(account.RootHash) {
 			continue
 		}
 
@@ -252,11 +256,9 @@ func (u *userAccountsSyncer) syncAccountDataTries(
 			defer u.throttler.EndProcessing()
 
 			log.Trace("sync data trie", "roothash", trieRootHash)
-			newErr := u.syncDataTrie(trieRootHash, address, ctx)
-			if newErr != nil {
-				errMutex.Lock()
-				errFound = newErr
-				errMutex.Unlock()
+			err := u.syncDataTrie(trieRootHash, address, ctx)
+			if err != nil {
+				leavesChannels.ErrChan.WriteInChanNonBlocking(err)
 			}
 			atomic.AddInt32(&u.numTriesSynced, 1)
 			log.Trace("finished sync data trie", "roothash", trieRootHash)
@@ -266,12 +268,7 @@ func (u *userAccountsSyncer) syncAccountDataTries(
 
 	wg.Wait()
 
-	err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
-	if err != nil {
-		return err
-	}
-
-	return errFound
+	return nil
 }
 
 func (u *userAccountsSyncer) printDataTrieStatistics() {
diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go
index f0080682107..51184d76d91 100644
--- a/state/syncer/userAccountsSyncer_test.go
+++ b/state/syncer/userAccountsSyncer_test.go
@@ -6,18 +6,20 @@ import (
 	"testing"
 	"time"
 
+	"github.com/multiversx/mx-chain-core-go/core"
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/hashing"
 	"github.com/multiversx/mx-chain-core-go/marshal"
 	"github.com/multiversx/mx-chain-go/api/mock"
 	"github.com/multiversx/mx-chain-go/common"
+	"github.com/multiversx/mx-chain-go/common/errChan"
 	"github.com/multiversx/mx-chain-go/config"
 	"github.com/multiversx/mx-chain-go/state"
 	"github.com/multiversx/mx-chain-go/state/syncer"
 	"github.com/multiversx/mx-chain-go/testscommon"
-	trieMocks "github.com/multiversx/mx-chain-go/testscommon/trie"
 	"github.com/multiversx/mx-chain-go/trie"
 	"github.com/multiversx/mx-chain-go/trie/hashesHolder"
+	"github.com/multiversx/mx-chain-go/trie/keyBuilder"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -171,43 +173,15 @@ func emptyTrie() common.Trie {
 func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) {
 	t.Parallel()
 
-	t.Run("failed to get trie root hash", func(t *testing.T) {
+	t.Run("nil leaves chan should fail", func(t *testing.T) {
 		t.Parallel()
 
-		expectedErr := errors.New("expected err")
-		tr := &trieMocks.TrieStub{
-			RootCalled: func() ([]byte, error) {
-				return nil, expectedErr
-			},
-		}
-
 		args := getDefaultUserAccountsSyncerArgs()
 		s, err := syncer.NewUserAccountsSyncer(args)
 		require.Nil(t, err)
 
-		err = s.SyncAccountDataTries(tr, context.TODO())
-		require.Equal(t, expectedErr, err)
-	})
-
-	t.Run("failed to get all leaves on channel", func(t *testing.T) {
-		t.Parallel()
-
-		expectedErr := errors.New("expected err")
-		tr := &trieMocks.TrieStub{
-			RootCalled: func() ([]byte, error) {
-				return []byte("rootHash"), nil
-			},
-			GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error {
-				return expectedErr
-			},
-		}
-
-		args := getDefaultUserAccountsSyncerArgs()
-		s, err := syncer.NewUserAccountsSyncer(args)
-		require.Nil(t, err)
-
-		err = s.SyncAccountDataTries(tr, context.TODO())
-		require.Equal(t, expectedErr, err)
+		err = s.SyncAccountDataTries(nil, context.TODO())
+		require.Equal(t, trie.ErrNilTrieIteratorChannels, err)
 	})
 
 	t.Run("throttler cannot process and closed context should fail", func(t *testing.T) {
@@ -254,10 +228,21 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) {
 		_ = tr.Update([]byte("ddog"), accountBytes)
 		_ = tr.Commit()
 
+		leavesChannels := &common.TrieIteratorChannels{
+			LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
+			ErrChan:    errChan.NewErrChanWrapper(),
+		}
+
+		rootHash, err := tr.RootHash()
+		require.Nil(t, err)
+
+		err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder())
+		require.Nil(t, err)
+
 		ctx, cancel := context.WithCancel(context.TODO())
 		cancel()
 
-		err = s.SyncAccountDataTries(tr, ctx)
+		err = s.SyncAccountDataTries(leavesChannels, ctx)
 		require.Equal(t, data.ErrTimeIsOut, err)
 	})
 
@@ -300,7 +285,18 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) {
 		_ = tr.Update([]byte("ddog"), accountBytes)
 		_ = tr.Commit()
 
-		err = s.SyncAccountDataTries(tr, context.TODO())
+		leavesChannels := &common.TrieIteratorChannels{
+			LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
+			ErrChan:    errChan.NewErrChanWrapper(),
+		}
+
+		rootHash, err := tr.RootHash()
+		require.Nil(t, err)
+
+		err = tr.GetAllLeavesOnChannel(leavesChannels, context.TODO(), rootHash, keyBuilder.NewDisabledKeyBuilder())
+		require.Nil(t, err)
+
+		err = s.SyncAccountDataTries(leavesChannels, context.TODO())
 		require.Nil(t, err)
 	})
 }
diff --git a/state/syncer/validatorAccountsSyncer.go b/state/syncer/validatorAccountsSyncer.go
index 34b87d1eb78..856d3ddc2cc 100644
--- a/state/syncer/validatorAccountsSyncer.go
+++ b/state/syncer/validatorAccountsSyncer.go
@@ -42,7 +42,6 @@ func NewValidatorAccountsSyncer(args ArgsNewValidatorAccountsSyncer) (*validator
 		timeoutHandler:            timeoutHandler,
 		shardId:                   core.MetachainShardId,
 		cacher:                    args.Cacher,
-		rootHash:                  nil,
 		maxTrieLevelInMemory:      args.MaxTrieLevelInMemory,
 		name:                      "peer accounts",
 		maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes,
@@ -75,12 +74,17 @@ func (v *validatorAccountsSyncer) SyncAccounts(rootHash []byte) error {
 
 	go v.printStatisticsAndUpdateMetrics(ctx)
 
-	mainTrie, err := v.syncMainTrie(rootHash, factory.ValidatorTrieNodesTopic, ctx)
+	err := v.syncMainTrie(
+		rootHash,
+		factory.ValidatorTrieNodesTopic,
+		ctx,
+		nil, // not used for validator accounts syncer
+	)
 	if err != nil {
 		return err
 	}
 
-	v.storageMarker.MarkStorerAsSyncedAndActive(mainTrie.GetStorageManager())
+	v.storageMarker.MarkStorerAsSyncedAndActive(v.trieStorageManager)
 
 	return nil
 }
diff --git a/trie/depthFirstSync.go b/trie/depthFirstSync.go
index 5f2d088fc7d..2af9bbb5e72 100644
--- a/trie/depthFirstSync.go
+++ b/trie/depthFirstSync.go
@@ -32,6 +32,7 @@ type depthFirstTrieSyncer struct {
 	checkNodesOnDisk          bool
 	nodes                     *trieNodesHandler
 	requestedHashes           map[string]*request
+	leavesChan                chan core.KeyValueHolder
 }
 
 // NewDepthFirstTrieSyncer creates a new instance of trieSyncer that uses the depth-first algorithm
@@ -59,6 +60,7 @@ func NewDepthFirstTrieSyncer(arg ArgTrieSyncer) (*depthFirstTrieSyncer, error) {
 		timeoutHandler:            arg.TimeoutHandler,
 		maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes,
 		checkNodesOnDisk:          arg.CheckNodesOnDisk,
+		leavesChan:                arg.LeavesChan,
 	}
 
 	return d, nil
@@ -252,6 +254,8 @@ func (d *depthFirstTrieSyncer) storeTrieNode(element node) error {
 
 	d.trieSyncStatistics.AddNumBytesReceived(uint64(numBytes))
 	d.updateStats(uint64(numBytes), element)
 
+	writeLeafNodeToChan(element, d.leavesChan)
+
 	return nil
 }
diff --git a/trie/depthFirstSync_test.go b/trie/depthFirstSync_test.go
index 6ace7fbdb3f..4fc6d9194aa 100644
--- a/trie/depthFirstSync_test.go
+++ b/trie/depthFirstSync_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -109,6 +110,7 @@ func TestDepthFirstTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) {
 
 	arg := createMockArgument(time.Minute)
 	arg.RequestHandler = createRequesterResolver(trSource, arg.InterceptedNodes, nil)
+	arg.LeavesChan = make(chan core.KeyValueHolder, 110)
 
 	d, _ := NewDepthFirstTrieSyncer(arg)
 	ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second*30)
@@ -135,6 +137,22 @@ func TestDepthFirstTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) {
 	assert.True(t, d.NumTrieNodes() > d.NumLeaves())
 	assert.True(t, d.NumBytes() > 0)
 	assert.True(t, d.Duration() > 0)
+
+	wg := &sync.WaitGroup{}
+	wg.Add(numKeysValues)
+
+	numLeavesOnChan := 0
+	go func() {
+		for range arg.LeavesChan {
+			numLeavesOnChan++
+			wg.Done()
+		}
+	}()
+
+	wg.Wait()
+
+	assert.Equal(t, numKeysValues, numLeavesOnChan)
+
 	log.Info("synced trie",
 		"num trie nodes", d.NumTrieNodes(),
 		"num leaves", d.NumLeaves(),
diff --git a/trie/doubleListSync.go b/trie/doubleListSync.go
index 6477023c7d2..cfd7120e7f8 100644
--- a/trie/doubleListSync.go
+++ b/trie/doubleListSync.go
@@ -44,6 +44,7 @@ type doubleListTrieSyncer struct {
 	existingNodes             map[string]node
 	missingHashes             map[string]struct{}
 	requestedHashes           map[string]*request
+	leavesChan                chan core.KeyValueHolder
 }
 
 // NewDoubleListTrieSyncer creates a new instance of trieSyncer that uses 2 list for keeping the "margin" nodes.
@@ -74,6 +75,7 @@ func NewDoubleListTrieSyncer(arg ArgTrieSyncer) (*doubleListTrieSyncer, error) {
 		timeoutHandler:            arg.TimeoutHandler,
 		maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes,
 		checkNodesOnDisk:          arg.CheckNodesOnDisk,
+		leavesChan:                arg.LeavesChan,
 	}
 
 	return d, nil
@@ -208,6 +210,8 @@ func (d *doubleListTrieSyncer) processExistingNodes() error {
 			return err
 		}
 
+		writeLeafNodeToChan(element, d.leavesChan)
+
 		d.timeoutHandler.ResetWatchdog()
 
 		var children []node
diff --git a/trie/doubleListSync_test.go b/trie/doubleListSync_test.go
index 719d578e5c6..a519db35d2e 100644
--- a/trie/doubleListSync_test.go
+++ b/trie/doubleListSync_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -240,6 +241,22 @@ func TestDoubleListTrieSyncer_StartSyncingNewTrieShouldWork(t *testing.T) {
 	assert.True(t, d.NumTrieNodes() > d.NumLeaves())
 	assert.True(t, d.NumBytes() > 0)
 	assert.True(t, d.Duration() > 0)
+
+	wg := &sync.WaitGroup{}
+	wg.Add(numKeysValues)
+
+	numLeavesOnChan := 0
+	go func() {
+		for range arg.LeavesChan {
+			numLeavesOnChan++
+			wg.Done()
+		}
+	}()
+
+	wg.Wait()
+
+	assert.Equal(t, numKeysValues, numLeavesOnChan)
+
 	log.Info("synced trie",
 		"num trie nodes", d.NumTrieNodes(),
 		"num leaves", d.NumLeaves(),
diff --git a/trie/sync.go b/trie/sync.go
index 465ebf71a99..5acd55c6b44 100644
--- a/trie/sync.go
+++ b/trie/sync.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/multiversx/mx-chain-core-go/core"
 	"github.com/multiversx/mx-chain-core-go/core/check"
+	"github.com/multiversx/mx-chain-core-go/core/keyValStorage"
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/hashing"
 	"github.com/multiversx/mx-chain-core-go/marshal"
@@ -41,6 +42,7 @@ type trieSyncer struct {
 	trieSyncStatistics        data.SyncStatisticsHandler
 	timeoutHandler            TimeoutHandler
 	maxHardCapForMissingNodes int
+	leavesChan                chan core.KeyValueHolder
 }
 
 const maxNewMissingAddedPerTurn = 10
@@ -58,6 +60,7 @@ type ArgTrieSyncer struct {
 	MaxHardCapForMissingNodes int
 	CheckNodesOnDisk          bool
 	TimeoutHandler            TimeoutHandler
+	LeavesChan                chan core.KeyValueHolder
 }
 
 // NewTrieSyncer creates a new instance of trieSyncer
@@ -86,6 +89,7 @@ func NewTrieSyncer(arg ArgTrieSyncer) (*trieSyncer, error) {
 		trieSyncStatistics:        arg.TrieSyncStatistics,
 		timeoutHandler:            arg.TimeoutHandler,
 		maxHardCapForMissingNodes: arg.MaxHardCapForMissingNodes,
+		leavesChan:                arg.LeavesChan,
 	}
 
 	return ts, nil
@@ -245,6 +249,9 @@ func (ts *trieSyncer) checkIfSynced() (bool, error) {
 			if err != nil {
 				return false, err
 			}
+
+			writeLeafNodeToChan(currentNode, ts.leavesChan)
+
 			ts.timeoutHandler.ResetWatchdog()
 
 			ts.updateStats(uint64(numBytes), currentNode)
@@ -364,6 +371,20 @@ func trieNode(
 	return decodedNode, nil
 }
 
+func writeLeafNodeToChan(element node, ch chan core.KeyValueHolder) {
+	if ch == nil {
+		return
+	}
+
+	leafNodeElement, isLeaf := element.(*leafNode)
+	if !isLeaf {
+		return
+	}
+
+	trieLeaf := keyValStorage.NewKeyValStorage(leafNodeElement.Key, leafNodeElement.Value)
+	ch <- trieLeaf
+}
+
 func (ts *trieSyncer) requestNodes() uint32 {
 	ts.mutOperation.RLock()
 	numUnResolvedNodes := uint32(len(ts.nodesForTrie))
diff --git a/trie/sync_test.go b/trie/sync_test.go
index cf56628be2c..3b783f90c11 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -3,9 +3,11 @@ package trie
 import (
 	"context"
 	"errors"
+	"sync"
 	"testing"
 	"time"
 
+	"github.com/multiversx/mx-chain-core-go/core"
 	"github.com/multiversx/mx-chain-core-go/core/check"
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-go/testscommon"
@@ -37,116 +39,121 @@ func createMockArgument(timeout time.Duration) ArgTrieSyncer {
 		TrieSyncStatistics:        statistics.NewTrieSyncStatistics(),
 		TimeoutHandler:            testscommon.NewTimeoutHandlerMock(timeout),
 		MaxHardCapForMissingNodes: 500,
+		LeavesChan:                make(chan core.KeyValueHolder, 100),
 	}
 }
 
-func TestNewTrieSyncer_NilRequestHandlerShouldErr(t *testing.T) {
+func TestNewTrieSyncer(t *testing.T) {
 	t.Parallel()
 
-	arg := createMockArgument(time.Minute)
-	arg.RequestHandler = nil
+	t.Run("nil request handler", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.Equal(t, err, ErrNilRequestHandler)
-}
+		arg := createMockArgument(time.Minute)
+		arg.RequestHandler = nil
 
-func TestNewTrieSyncer_NilInterceptedNodesShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.Equal(t, err, ErrNilRequestHandler)
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.InterceptedNodes = nil
+	t.Run("nil intercepted nodes", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.Equal(t, err, data.ErrNilCacher)
-}
+		arg := createMockArgument(time.Minute)
+		arg.InterceptedNodes = nil
 
-func TestNewTrieSyncer_EmptyTopicShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.Equal(t, err, data.ErrNilCacher)
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.Topic = ""
+	t.Run("empty topic should fail", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.Equal(t, err, ErrInvalidTrieTopic)
-}
+		arg := createMockArgument(time.Minute)
+		arg.Topic = ""
 
-func TestNewTrieSyncer_NilTrieStatisticsShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.Equal(t, err, ErrInvalidTrieTopic)
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.TrieSyncStatistics = nil
+	t.Run("nil trie statistics", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.Equal(t, err, ErrNilTrieSyncStatistics)
-}
+		arg := createMockArgument(time.Minute)
+		arg.TrieSyncStatistics = nil
 
-func TestNewTrieSyncer_NilDatabaseShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.Equal(t, err, ErrNilTrieSyncStatistics)
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.DB = nil
+	t.Run("nil database", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.True(t, errors.Is(err, ErrNilDatabase))
-}
+		arg := createMockArgument(time.Minute)
+		arg.DB = nil
 
-func TestNewTrieSyncer_NilMarshalizerShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.True(t, errors.Is(err, ErrNilDatabase))
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.Marshalizer = nil
+	t.Run("nil marshalizer", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.True(t, errors.Is(err, ErrNilMarshalizer))
-}
+		arg := createMockArgument(time.Minute)
+		arg.Marshalizer = nil
 
-func TestNewTrieSyncer_NilHasherShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.True(t, errors.Is(err, ErrNilMarshalizer))
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.Hasher = nil
+	t.Run("nil hasher", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.True(t, errors.Is(err, ErrNilHasher))
-}
+		arg := createMockArgument(time.Minute)
+		arg.Hasher = nil
 
-func TestNewTrieSyncer_NilTimeoutHandlerShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.True(t, errors.Is(err, ErrNilHasher))
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.TimeoutHandler = nil
+	t.Run("nil timeout handler", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.True(t, errors.Is(err, ErrNilTimeoutHandler))
-}
+		arg := createMockArgument(time.Minute)
+		arg.TimeoutHandler = nil
 
-func TestNewTrieSyncer_InvalidMaxHardCapForMissingNodesShouldErr(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.True(t, errors.Is(err, ErrNilTimeoutHandler))
+	})
 
-	arg := createMockArgument(time.Minute)
-	arg.MaxHardCapForMissingNodes = 0
+	t.Run("invalid max hard capacity for missing nodes", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.True(t, check.IfNil(ts))
-	assert.True(t, errors.Is(err, ErrInvalidMaxHardCapForMissingNodes))
-}
+		arg := createMockArgument(time.Minute)
+		arg.MaxHardCapForMissingNodes = 0
 
-func TestNewTrieSyncer_ShouldWork(t *testing.T) {
-	t.Parallel()
+		ts, err := NewTrieSyncer(arg)
+		assert.True(t, check.IfNil(ts))
+		assert.True(t, errors.Is(err, ErrInvalidMaxHardCapForMissingNodes))
+	})
 
-	arg := createMockArgument(time.Minute)
+	t.Run("should work", func(t *testing.T) {
+		t.Parallel()
 
-	ts, err := NewTrieSyncer(arg)
-	assert.False(t, check.IfNil(ts))
-	assert.Nil(t, err)
+		arg := createMockArgument(time.Minute)
+
+		ts, err := NewTrieSyncer(arg)
+		assert.False(t, check.IfNil(ts))
+		assert.Nil(t, err)
+	})
 }
 
 func TestTrieSync_InterceptedNodeShouldNotBeAddedToNodesForTrieIfNodeReceived(t *testing.T) {
@@ -217,6 +224,10 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) {
 	err = bn.commitSnapshot(db, nil, nil, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, 0)
 	require.Nil(t, err)
 
+	leaves, err := bn.getChildren(db)
+	require.Nil(t, err)
+	numLeaves := len(leaves)
+
 	arg := createMockArgument(timeout)
 	arg.RequestHandler = &testscommon.RequestHandlerStub{
 		RequestTrieNodesCalled: func(destShardID uint32, hashes [][]byte, topic string) {
@@ -232,4 +243,19 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) {
 
 	err = ts.StartSyncing(rootHash, context.Background())
 	assert.Nil(t, err)
+
+	wg := &sync.WaitGroup{}
+	wg.Add(numLeaves)
+
+	numLeavesOnChan := 0
+	go func() {
+		for range arg.LeavesChan {
+			numLeavesOnChan++
+			wg.Done()
+		}
+	}()
+
+	wg.Wait()
+
+	assert.Equal(t, numLeaves, numLeavesOnChan)
 }
diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go
index c5304e45428..99fa6895bb7 100644
--- a/trie/trieStorageManager.go
+++ b/trie/trieStorageManager.go
@@ -330,19 +330,19 @@ func (tsm *trieStorageManager) TakeSnapshot(
 ) {
 	if iteratorChannels.ErrChan == nil {
 		log.Error("programming error in trieStorageManager.TakeSnapshot, cannot take snapshot because errChan is nil")
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
 
 	if tsm.IsClosed() {
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
 
 	if bytes.Equal(rootHash, common.EmptyTrieHash) {
 		log.Trace("should not snapshot an empty trie")
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
@@ -363,7 +363,7 @@ func (tsm *trieStorageManager) TakeSnapshot(
 	case tsm.snapshotReq <- snapshotEntry:
 	case <-tsm.closer.ChanClose():
 		tsm.ExitPruningBufferingMode()
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 	}
 }
@@ -380,19 +380,19 @@ func (tsm *trieStorageManager) SetCheckpoint(
 ) {
 	if iteratorChannels.ErrChan == nil {
 		log.Error("programming error in trieStorageManager.SetCheckpoint, cannot set checkpoint because errChan is nil")
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
 
 	if tsm.IsClosed() {
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
 
 	if bytes.Equal(rootHash, common.EmptyTrieHash) {
 		log.Trace("should not set checkpoint for empty trie")
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 		return
 	}
@@ -410,21 +410,15 @@ func (tsm *trieStorageManager) SetCheckpoint(
 	case tsm.checkpointReq <- checkpointEntry:
 	case <-tsm.closer.ChanClose():
 		tsm.ExitPruningBufferingMode()
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 		stats.SnapshotFinished()
 	}
 }
 
-func safelyCloseChan(ch chan core.KeyValueHolder) {
-	if ch != nil {
-		close(ch)
-	}
-}
-
 func (tsm *trieStorageManager) finishOperation(snapshotEntry *snapshotsQueueEntry, message string) {
 	tsm.ExitPruningBufferingMode()
 	log.Trace(message, "rootHash", snapshotEntry.rootHash)
-	safelyCloseChan(snapshotEntry.iteratorChannels.LeavesChan)
+	common.CloseKeyValueHolderChan(snapshotEntry.iteratorChannels.LeavesChan)
 	snapshotEntry.stats.SnapshotFinished()
 }
diff --git a/trie/trieStorageManagerWithoutCheckpoints.go b/trie/trieStorageManagerWithoutCheckpoints.go
index d2f4b93e507..975a9a10111 100644
--- a/trie/trieStorageManagerWithoutCheckpoints.go
+++ b/trie/trieStorageManagerWithoutCheckpoints.go
@@ -30,7 +30,7 @@ func (tsm *trieStorageManagerWithoutCheckpoints) SetCheckpoint(
 	stats common.SnapshotStatisticsHandler,
 ) {
 	if iteratorChannels != nil {
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 	}
 	stats.SnapshotFinished()
 
diff --git a/trie/trieStorageManagerWithoutSnapshot.go b/trie/trieStorageManagerWithoutSnapshot.go
index 337a74f8f9d..7e538eaf184 100644
--- a/trie/trieStorageManagerWithoutSnapshot.go
+++ b/trie/trieStorageManagerWithoutSnapshot.go
@@ -38,7 +38,7 @@ func (tsm *trieStorageManagerWithoutSnapshot) PutInEpochWithoutCache(key []byte,
 // TakeSnapshot does nothing, as snapshots are disabled for this implementation
 func (tsm *trieStorageManagerWithoutSnapshot) TakeSnapshot(_ string, _ []byte, _ []byte, iteratorChannels *common.TrieIteratorChannels, _ chan []byte, stats common.SnapshotStatisticsHandler, _ uint32) {
 	if iteratorChannels != nil {
-		safelyCloseChan(iteratorChannels.LeavesChan)
+		common.CloseKeyValueHolderChan(iteratorChannels.LeavesChan)
 	}
 	stats.SnapshotFinished()
 }