diff --git a/common/channels.go b/common/channels.go index 3d00dcde162..177ac89f5c5 100644 --- a/common/channels.go +++ b/common/channels.go @@ -7,13 +7,3 @@ func GetClosedUnbufferedChannel() chan struct{} { return ch } - -// GetErrorFromChanNonBlocking will get the error from channel -func GetErrorFromChanNonBlocking(errChan chan error) error { - select { - case err := <-errChan: - return err - default: - return nil - } -} diff --git a/common/channels_test.go b/common/channels_test.go index a5fad97d1a4..4e2828e2d6a 100644 --- a/common/channels_test.go +++ b/common/channels_test.go @@ -1,11 +1,8 @@ package common import ( - "errors" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -39,57 +36,3 @@ func didTriggerHappen(ch chan struct{}) bool { return false } } - -func TestErrFromChannel(t *testing.T) { - t.Parallel() - - t.Run("empty channel, should return nil", func(t *testing.T) { - t.Parallel() - - t.Run("unbuffered chan", func(t *testing.T) { - t.Parallel() - - errChan := make(chan error) - assert.Nil(t, GetErrorFromChanNonBlocking(errChan)) - }) - - t.Run("buffered chan", func(t *testing.T) { - t.Parallel() - - errChan := make(chan error, 1) - assert.Nil(t, GetErrorFromChanNonBlocking(errChan)) - }) - }) - - t.Run("non empty channel, should return error", func(t *testing.T) { - t.Parallel() - - t.Run("unbuffered chan", func(t *testing.T) { - t.Parallel() - - expectedErr := errors.New("expected error") - errChan := make(chan error) - go func() { - errChan <- expectedErr - }() - - time.Sleep(time.Second) // allow the go routine to start - - assert.Equal(t, expectedErr, GetErrorFromChanNonBlocking(errChan)) - }) - - t.Run("buffered chan", func(t *testing.T) { - t.Parallel() - - for i := 1; i < 10; i++ { - errChan := make(chan error, i) - expectedErr := errors.New("expected error") - for j := 0; j < i; j++ { - errChan <- expectedErr - } - - assert.Equal(t, expectedErr, GetErrorFromChanNonBlocking(errChan)) - } - }) - }) -} diff --git a/common/errChan/errChan.go b/common/errChan/errChan.go new file mode 100644 index 00000000000..47cf29e320b --- /dev/null +++ b/common/errChan/errChan.go @@ -0,0 +1,69 @@ +package errChan + +import "sync" + +type errChanWrapper struct { + ch chan error + closed bool + closeMutex sync.RWMutex +} + +// NewErrChanWrapper creates a new errChanWrapper +func NewErrChanWrapper() *errChanWrapper { + return &errChanWrapper{ + ch: make(chan error, 1), + closed: false, + } +} + +// WriteInChanNonBlocking will send the given error on the channel if the chan is not blocked +func (ec *errChanWrapper) WriteInChanNonBlocking(err error) { + ec.closeMutex.RLock() + defer ec.closeMutex.RUnlock() + + if ec.closed { + return + } + + select { + case ec.ch <- err: + default: + } +} + +// ReadFromChanNonBlocking will read from the channel, or return nil if no error was sent on the channel +func (ec *errChanWrapper) ReadFromChanNonBlocking() error { + select { + case err := <-ec.ch: + return err + default: + return nil + } +} + +// Close will close the channel +func (ec *errChanWrapper) Close() { + ec.closeMutex.Lock() + defer ec.closeMutex.Unlock() + + if ec.closed { + return + } + + if ec.ch == nil { + return + } + + close(ec.ch) + ec.closed = true +} + +// Len returns the length of the channel +func (ec *errChanWrapper) Len() int { + return len(ec.ch) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (ec *errChanWrapper) IsInterfaceNil() bool { + return ec == nil +} diff --git a/common/errChan/errChan_test.go b/common/errChan/errChan_test.go new file mode 100644 index 00000000000..3d88f358015 --- /dev/null +++ b/common/errChan/errChan_test.go @@ -0,0 +1,136 @@ +package errChan + +import ( + "fmt" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewErrChan(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + assert.False(t, check.IfNil(ec)) + assert.Equal(t, 1, cap(ec.ch)) +} + +func TestErrChan_WriteInChanNonBlocking(t *testing.T) { + t.Parallel() + + t.Run("write in a nil channel", func(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + ec.ch = nil + ec.WriteInChanNonBlocking(fmt.Errorf("err1")) + + assert.Equal(t, 0, len(ec.ch)) + }) + + t.Run("write in a closed channel", func(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + ec.Close() + ec.WriteInChanNonBlocking(fmt.Errorf("err1")) + + assert.Equal(t, 0, len(ec.ch)) + }) + + t.Run("should work", func(t *testing.T) { + expectedErr := fmt.Errorf("err1") + ec := NewErrChanWrapper() + ec.WriteInChanNonBlocking(expectedErr) + ec.WriteInChanNonBlocking(fmt.Errorf("err2")) + ec.WriteInChanNonBlocking(fmt.Errorf("err3")) + + assert.Equal(t, 1, len(ec.ch)) + assert.Equal(t, expectedErr, <-ec.ch) + assert.Equal(t, 0, len(ec.ch)) + }) +} + +func TestErrChan_ReadFromChanNonBlocking(t *testing.T) { + t.Parallel() + + expectedErr := fmt.Errorf("err1") + ec := NewErrChanWrapper() + ec.ch <- expectedErr + + assert.Equal(t, 1, len(ec.ch)) + assert.Equal(t, expectedErr, ec.ReadFromChanNonBlocking()) + assert.Equal(t, 0, len(ec.ch)) + assert.Nil(t, ec.ReadFromChanNonBlocking()) +} + +func TestErrChan_Close(t *testing.T) { + t.Parallel() + + t.Run("close an already closed channel", func(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + ec.Close() + + assert.True(t, ec.closed) + ec.Close() + }) + + t.Run("close a nil channel", func(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + ec.ch = nil + ec.Close() + + assert.False(t, ec.closed) + }) +} + +func TestErrChan_Len(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + assert.Equal(t, 0, ec.Len()) + + ec.ch <- fmt.Errorf("err1") + assert.Equal(t, 1, ec.Len()) + + ec.WriteInChanNonBlocking(fmt.Errorf("err2")) + assert.Equal(t, 1, ec.Len()) +} + +func TestErrChan_ConcurrentOperations(t *testing.T) { + t.Parallel() + + ec := NewErrChanWrapper() + numOperations := 1000 + numMethods := 2 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + + if idx == numOperations-100 { + ec.Close() + } + + operation := idx % numMethods + switch operation { + case 0: + ec.WriteInChanNonBlocking(fmt.Errorf("err")) + case 1: + _ = ec.ReadFromChanNonBlocking() + default: + assert.Fail(t, "invalid numMethods") + } + + wg.Done() + }(i) + } + + wg.Wait() +} diff --git a/common/interface.go b/common/interface.go index a58b6aa94db..16f448179ec 100644 --- a/common/interface.go +++ b/common/interface.go @@ -13,7 +13,15 @@ import ( // TrieIteratorChannels defines the channels that are being used when iterating the trie nodes type TrieIteratorChannels struct { LeavesChan chan core.KeyValueHolder - ErrChan chan error + ErrChan BufferedErrChan +} + +// BufferedErrChan is an interface that defines the methods for a buffered error channel +type BufferedErrChan interface { + WriteInChanNonBlocking(err error) + ReadFromChanNonBlocking() error + Close() + IsInterfaceNil() bool } // Trie is an interface for Merkle Trees implementations diff --git a/debug/process/stateExport.go b/debug/process/stateExport.go index 9fbdd6ce1bc..831aaebfc0e 100644 --- a/debug/process/stateExport.go +++ b/debug/process/stateExport.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" ) @@ -66,7 +67,7 @@ func getCode(accountsDB state.AccountsAdapter, codeHash []byte) ([]byte, error) func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte) ([]string, error) { leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := accountsDB.GetAllLeaves(leavesChannels, context.Background(), rootHash) @@ -89,7 +90,7 @@ func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte) hex.EncodeToString(valWithoutSuffix))) } - err = <-leavesChannels.ErrChan + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s", err, hex.EncodeToString(rootHash), hex.EncodeToString(address)) diff --git a/epochStart/metachain/systemSCs.go b/epochStart/metachain/systemSCs.go index 642053ad7d1..645f54ce3ea 100644 --- a/epochStart/metachain/systemSCs.go +++ b/epochStart/metachain/systemSCs.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" vInfo "github.com/multiversx/mx-chain-go/common/validatorInfo" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" @@ -1102,7 +1103,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -1125,7 +1126,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc } } - err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index 60860011ef4..f64de2f7447 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/outport" nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -874,7 +875,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = pcf.state.AccountsAdapter().GetAllLeaves(leavesChannels, context.Background(), rootHash) if err != nil { @@ -901,7 +902,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string } } - err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/integrationTests/state/stateTrieClose/stateTrieClose_test.go b/integrationTests/state/stateTrieClose/stateTrieClose_test.go index ab18ce244b6..985f49c660a 100644 --- a/integrationTests/state/stateTrieClose/stateTrieClose_test.go +++ b/integrationTests/state/stateTrieClose/stateTrieClose_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/testscommon" @@ -36,25 +37,25 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { rootHash, _ := tr.RootHash() leavesChannel1 := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) time.Sleep(time.Second) // allow the go routine to start idx, _ := gc.Snapshot() diff := gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 1) // can be 0 on a fast running host - err := common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan) + err := leavesChannel1.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) leavesChannel1 = &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 2) - err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan) + err = leavesChannel1.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) _ = tr.Update([]byte("god"), []byte("puppy")) @@ -63,13 +64,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { rootHash, _ = tr.RootHash() leavesChannel1 = &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } _ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.Equal(t, 3, len(diff), fmt.Sprintf("%v", diff)) - err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan) + err = leavesChannel1.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) _ = tr.Update([]byte("eggod"), []byte("cat")) @@ -78,14 +79,14 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { rootHash, _ = tr.RootHash() leavesChannel2 := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } _ = tr.GetAllLeavesOnChannel(leavesChannel2, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) time.Sleep(time.Second) // allow the go routine to start idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 4) - err = common.GetErrorFromChanNonBlocking(leavesChannel2.ErrChan) + err = leavesChannel2.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) for range leavesChannel1.LeavesChan { @@ -94,7 +95,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 3) - err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan) + err = leavesChannel1.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) for range leavesChannel2.LeavesChan { @@ -103,7 +104,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { idx, _ = gc.Snapshot() diff = gc.DiffGoRoutines(idxInitial, idx) assert.True(t, len(diff) <= 2) - err = common.GetErrorFromChanNonBlocking(leavesChannel2.ErrChan) + err = leavesChannel2.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) err = tr.Close() diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 0e7387825fd..7b2e28e5866 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/epochStart/notifier" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process/factory" @@ -329,13 +330,13 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves rootHash, _ := accState.RootHash() leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = accState.GetAllLeaves(leavesChannel, context.Background(), rootHash) for range leavesChannel.LeavesChan { } require.Nil(t, err) - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() require.Nil(t, err) requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie)) @@ -357,7 +358,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves leavesChannel = &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = nRequester.AccntState.GetAllLeaves(leavesChannel, context.Background(), rootHash) assert.Nil(t, err) @@ -365,7 +366,7 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves for range leavesChannel.LeavesChan { numLeaves++ } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() require.Nil(t, err) assert.Equal(t, numAccounts, numLeaves) checkAllDataTriesAreSynced(t, numDataTrieLeaves, requesterTrie, dataTrieRootHashes) @@ -559,7 +560,7 @@ func addAccountsToState(t *testing.T, numAccounts int, numDataTrieLeaves int, ac func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) require.Nil(t, err) @@ -569,7 +570,7 @@ func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int { numLeaves++ } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() require.Nil(t, err) return numLeaves diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index 532b7a64673..995d1659d8a 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -34,6 +34,7 @@ import ( nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" @@ -3148,7 +3149,7 @@ func GetTokenIdentifier(nodes []*TestProcessorNode, ticker []byte) []byte { rootHash, _ := userAcc.DataTrie().RootHash() chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } _ = userAcc.DataTrie().GetAllLeavesOnChannel(chLeaves, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) for leaf := range chLeaves.LeavesChan { @@ -3159,7 +3160,7 @@ func GetTokenIdentifier(nodes []*TestProcessorNode, ticker []byte) []byte { return leaf.Key() } - err := common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err := chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { log.Error("error getting all leaves from channel", "err", err) } diff --git a/integrationTests/vm/txsFee/asyncESDT_test.go b/integrationTests/vm/txsFee/asyncESDT_test.go index a4318ad54f0..1f802023506 100644 --- a/integrationTests/vm/txsFee/asyncESDT_test.go +++ b/integrationTests/vm/txsFee/asyncESDT_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/integrationTests/vm/txsFee/utils" @@ -542,7 +543,7 @@ func TestAsyncESDTCallForThirdContractShouldWork(t *testing.T) { leaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, 1), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = testContext.Accounts.GetAllLeaves(leaves, context.Background(), roothash) require.Nil(t, err) @@ -551,6 +552,6 @@ func TestAsyncESDTCallForThirdContractShouldWork(t *testing.T) { // do nothing, just iterate } - err = <-leaves.ErrChan + err = leaves.ErrChan.ReadFromChanNonBlocking() require.Nil(t, err) } diff --git a/node/node.go b/node/node.go index b6553c127ed..571432a277b 100644 --- a/node/node.go +++ b/node/node.go @@ -22,6 +22,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/transaction" disabledSig "github.com/multiversx/mx-chain-crypto-go/signing/disabled/singlesig" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/debug" "github.com/multiversx/mx-chain-go/facade" @@ -224,7 +225,7 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -252,7 +253,7 @@ func (n *Node) GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]strin } } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } @@ -305,7 +306,7 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -324,7 +325,7 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, mapToReturn[hex.EncodeToString(leaf.Key())] = hex.EncodeToString(value) } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, api.BlockInfo{}, err } @@ -425,7 +426,7 @@ func (n *Node) getTokensIDsWithFilter( chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -448,7 +449,7 @@ func (n *Node) getTokensIDsWithFilter( } } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, api.BlockInfo{}, err } @@ -566,7 +567,7 @@ func (n *Node) GetAllESDTTokens(address string, options api.AccountQueryOptions, chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -604,7 +605,7 @@ func (n *Node) GetAllESDTTokens(address string, options api.AccountQueryOptions, allESDTs[tokenName] = esdtToken } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, api.BlockInfo{}, err } diff --git a/node/node_test.go b/node/node_test.go index c1d1b47a4a4..a9c1e02e4f2 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -443,7 +443,7 @@ func TestNode_GetKeyValuePairs(t *testing.T) { trieLeaf2 := keyValStorage.NewKeyValStorage(k2, append(v2, suffix...)) leavesChannels.LeavesChan <- trieLeaf2 close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -502,7 +502,7 @@ func TestNode_GetKeyValuePairs_GetAllLeavesShouldFail(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { - leavesChannels.ErrChan <- expectedErr + leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(leavesChannels.LeavesChan) }() @@ -557,7 +557,7 @@ func TestNode_GetKeyValuePairsContextShouldTimeout(t *testing.T) { go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -838,7 +838,7 @@ func TestNode_GetAllESDTTokens(t *testing.T) { trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, nil) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -894,7 +894,7 @@ func TestNode_GetAllESDTTokens_GetAllLeavesShouldFail(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder) error { go func() { - leavesChannels.ErrChan <- expectedErr + leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(leavesChannels.LeavesChan) }() @@ -951,7 +951,7 @@ func TestNode_GetAllESDTTokensContextShouldTimeout(t *testing.T) { go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -1083,7 +1083,7 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { leavesChannels.LeavesChan <- trieLeaf wg.Done() close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() wg.Wait() @@ -1169,7 +1169,7 @@ func TestNode_GetAllIssuedESDTs(t *testing.T) { trieLeaf = keyValStorage.NewKeyValStorage(nftToken, append(nftMarshalledData, nftSuffix...)) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -1255,7 +1255,7 @@ func TestNode_GetESDTsWithRole(t *testing.T) { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -1335,7 +1335,7 @@ func TestNode_GetESDTsRoles(t *testing.T) { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -1400,7 +1400,7 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { trieLeaf := keyValStorage.NewKeyValStorage(esdtToken, append(marshalledData, esdtSuffix...)) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -1457,7 +1457,7 @@ func TestNode_GetNFTTokenIDsRegisteredByAddressContextShouldTimeout(t *testing.T go func() { time.Sleep(time.Second) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil diff --git a/node/trieIterators/delegatedListProcessor.go b/node/trieIterators/delegatedListProcessor.go index 5db7ecb4116..acf6c763128 100644 --- a/node/trieIterators/delegatedListProcessor.go +++ b/node/trieIterators/delegatedListProcessor.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/trie/keyBuilder" @@ -129,7 +130,7 @@ func (dlp *delegatedListProcessor) getDelegatorsList(delegationSC []byte, ctx co chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = delegatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -146,7 +147,7 @@ func (dlp *delegatedListProcessor) getDelegatorsList(delegationSC []byte, ctx co delegators = append(delegators, leafKey) } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/node/trieIterators/delegatedListProcessor_test.go b/node/trieIterators/delegatedListProcessor_test.go index 090f8ce68e1..c669b43924e 100644 --- a/node/trieIterators/delegatedListProcessor_test.go +++ b/node/trieIterators/delegatedListProcessor_test.go @@ -232,7 +232,7 @@ func createDelegationScAccount(address []byte, leaves [][]byte, rootHash []byte, } close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil diff --git a/node/trieIterators/directStakedListProcessor.go b/node/trieIterators/directStakedListProcessor.go index 0ff046919b4..884607e7d7f 100644 --- a/node/trieIterators/directStakedListProcessor.go +++ b/node/trieIterators/directStakedListProcessor.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/vm" @@ -56,7 +57,7 @@ func (dslp *directStakedListProcessor) getAllStakedAccounts(validatorAccount sta chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -90,7 +91,7 @@ func (dslp *directStakedListProcessor) getAllStakedAccounts(validatorAccount sta stakedAccounts = append(stakedAccounts, val) } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/node/trieIterators/directStakedListProcessor_test.go b/node/trieIterators/directStakedListProcessor_test.go index 18b0bba952d..29e19f82542 100644 --- a/node/trieIterators/directStakedListProcessor_test.go +++ b/node/trieIterators/directStakedListProcessor_test.go @@ -162,7 +162,7 @@ func createValidatorScAccount(address []byte, leaves [][]byte, rootHash []byte, } close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil diff --git a/node/trieIterators/stakeValuesProcessor.go b/node/trieIterators/stakeValuesProcessor.go index c77169203d3..17109690b98 100644 --- a/node/trieIterators/stakeValuesProcessor.go +++ b/node/trieIterators/stakeValuesProcessor.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/trie/keyBuilder" @@ -98,7 +99,7 @@ func (svp *stakedValuesProcessor) computeBaseStakedAndTopUp(ctx context.Context) // TODO investigate if a call to GetAllLeavesKeysOnChannel (without values) might increase performance chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = validatorAccount.DataTrie().GetAllLeavesOnChannel(chLeaves, ctx, rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -123,7 +124,7 @@ func (svp *stakedValuesProcessor) computeBaseStakedAndTopUp(ctx context.Context) totalTopUp = totalTopUp.Add(totalTopUp, info.topUpValue) } - err = common.GetErrorFromChanNonBlocking(chLeaves.ErrChan) + err = chLeaves.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, nil, err } diff --git a/node/trieIterators/stakeValuesProcessor_test.go b/node/trieIterators/stakeValuesProcessor_test.go index 166b4fc37f0..6a81e0ddd76 100644 --- a/node/trieIterators/stakeValuesProcessor_test.go +++ b/node/trieIterators/stakeValuesProcessor_test.go @@ -195,7 +195,7 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue_ContextShouldTimeout(t *t GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, _ context.Context, _ []byte, _ common.KeyBuilder) error { time.Sleep(time.Second) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() return nil }, RootCalled: func() ([]byte, error) { @@ -297,7 +297,7 @@ func TestTotalStakedValueProcessor_GetTotalStakedValue(t *testing.T) { channels.LeavesChan <- leaf6 close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() }() return nil diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index a9c47516a55..02ca6f7652e 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -21,6 +21,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/common/logging" "github.com/multiversx/mx-chain-go/config" @@ -1733,7 +1734,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = userAccountsDb.GetAllLeaves(iteratorChannels, context.Background(), rootHash) if err != nil { @@ -1762,7 +1763,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl if len(rh) != 0 { dataTrie := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } errDataTrieGet := userAccountsDb.GetAllLeaves(dataTrie, context.Background(), rh) if errDataTrieGet != nil { @@ -1774,7 +1775,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl currentSize += len(lf.Value()) } - err = common.GetErrorFromChanNonBlocking(dataTrie.ErrChan) + err = dataTrie.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } @@ -1790,7 +1791,7 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBl balanceSum.Add(balanceSum, userAccount.GetBalance()) } - err = common.GetErrorFromChanNonBlocking(iteratorChannels.ErrChan) + err = iteratorChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index a8525909b4f..ba89195248f 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -1892,7 +1892,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededShouldWork(t *testing.T) { }, GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() return nil }, }, @@ -1936,7 +1936,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T }, GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() return expectedErr }, }, @@ -1973,7 +1973,7 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T return rootHash, nil }, GetAllLeavesCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { - channels.ErrChan <- expectedErr + channels.ErrChan.WriteInChanNonBlocking(expectedErr) close(channels.LeavesChan) return nil }, @@ -2033,14 +2033,14 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededShouldUseDataTrieIfNeededW if bytes.Equal(rootHash, rh) { calledWithUserAccountRootHash = true close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() return nil } go func() { channels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("address"), []byte("bytes")) close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() }() return nil diff --git a/process/peer/process.go b/process/peer/process.go index d5ed5d06b2e..3eac66835a8 100644 --- a/process/peer/process.go +++ b/process/peer/process.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/validatorInfo" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -447,7 +448,7 @@ func (vs *validatorStatistics) getValidatorDataFromLeaves( validators[currentShardId] = append(validators[currentShardId], validatorInfoData) } - err := common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err := leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } @@ -562,7 +563,7 @@ func (vs *validatorStatistics) GetValidatorInfoForRootHash(rootHash []byte) (map leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := vs.peerAdapter.GetAllLeaves(leavesChannels, context.Background(), rootHash) if err != nil { diff --git a/process/peer/process_test.go b/process/peer/process_test.go index ee1bab03e7f..5f84e61d749 100644 --- a/process/peer/process_test.go +++ b/process/peer/process_test.go @@ -1969,7 +1969,7 @@ func TestValidatorStatistics_ResetValidatorStatisticsAtNewEpoch(t *testing.T) { go func() { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) close(ch.LeavesChan) - close(ch.ErrChan) + ch.ErrChan.Close() }() return nil @@ -2032,7 +2032,7 @@ func TestValidatorStatistics_Process(t *testing.T) { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytesMeta, marshalizedPaMeta) close(ch.LeavesChan) - close(ch.ErrChan) + ch.ErrChan.Close() }() return nil @@ -2078,7 +2078,7 @@ func TestValidatorStatistics_GetValidatorInfoForRootHash(t *testing.T) { peerAdapter.GetAllLeavesCalled = func(ch *common.TrieIteratorChannels, ctx context.Context, rootHash []byte) error { if bytes.Equal(rootHash, hash) { go func() { - ch.ErrChan <- expectedErr + ch.ErrChan.WriteInChanNonBlocking(expectedErr) close(ch.LeavesChan) }() @@ -2108,7 +2108,7 @@ func TestValidatorStatistics_GetValidatorInfoForRootHash(t *testing.T) { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytesMeta, marshalizedPaMeta) close(ch.LeavesChan) - close(ch.ErrChan) + ch.ErrChan.Close() }() return nil @@ -2555,7 +2555,7 @@ func updateArgumentsWithNeeded(arguments peer.ArgValidatorStatisticsProcessor) { ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytes0, marshalizedPa0) ch.LeavesChan <- keyValStorage.NewKeyValStorage(addrBytesMeta, marshalizedPaMeta) close(ch.LeavesChan) - close(ch.ErrChan) + ch.ErrChan.Close() }() return nil diff --git a/process/txsimulator/wrappedAccountsDB_test.go b/process/txsimulator/wrappedAccountsDB_test.go index 1bf48e18531..e83fe6a0d58 100644 --- a/process/txsimulator/wrappedAccountsDB_test.go +++ b/process/txsimulator/wrappedAccountsDB_test.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/state" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -150,11 +151,11 @@ func TestReadOnlyAccountsDB_ReadOperationsShouldWork(t *testing.T) { allLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = roAccDb.GetAllLeaves(allLeaves, context.Background(), nil) require.NoError(t, err) - err = common.GetErrorFromChanNonBlocking(allLeaves.ErrChan) + err = allLeaves.ErrChan.ReadFromChanNonBlocking() require.NoError(t, err) } diff --git a/state/accountsDB.go b/state/accountsDB.go index 48a57964164..5facd718da4 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/trie/keyBuilder" @@ -1039,7 +1040,7 @@ func (adb *AccountsDB) recreateTrie(options common.RootHashHolder) error { func (adb *AccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie, error) { leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, leavesChannelSize), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } mainTrie := adb.getMainTrie() err := mainTrie.GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) @@ -1070,7 +1071,7 @@ func (adb *AccountsDB) RecreateAllTries(rootHash []byte) (map[string]common.Trie } } - err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } @@ -1148,7 +1149,7 @@ func (adb *AccountsDB) SnapshotState(rootHash []byte) { missingNodesChannel := make(chan []byte, missingNodesChannelSize) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, leavesChannelSize), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } stats := newSnapshotStatistics(1, 1) @@ -1259,7 +1260,7 @@ func (adb *AccountsDB) processSnapshotCompletion( stats *snapshotStatistics, trieStorageManager common.StorageManager, missingNodesCh chan []byte, - errChan chan error, + errChan common.BufferedErrChan, rootHash []byte, metrics *accountMetrics, epoch uint32, @@ -1269,15 +1270,15 @@ func (adb *AccountsDB) processSnapshotCompletion( defer func() { adb.isSnapshotInProgress.Reset() adb.updateMetricsOnSnapshotCompletion(metrics, stats) - close(errChan) + errChan.Close() }() - containsErrorDuringSnapshot := emptyErrChanReturningHadContained(errChan) - shouldNotMarkActive := trieStorageManager.IsClosed() || containsErrorDuringSnapshot + errorDuringSnapshot := errChan.ReadFromChanNonBlocking() + shouldNotMarkActive := trieStorageManager.IsClosed() || errorDuringSnapshot != nil if shouldNotMarkActive { log.Debug("will not set activeDB in epoch as the snapshot might be incomplete", "epoch", epoch, "trie storage manager closed", trieStorageManager.IsClosed(), - "errors during snapshot found", containsErrorDuringSnapshot) + "errors during snapshot found", errorDuringSnapshot) return } @@ -1289,7 +1290,7 @@ func (adb *AccountsDB) processSnapshotCompletion( handleLoggingWhenError("error while putting active DB value into main storer", errPut) } -func (adb *AccountsDB) syncMissingNodes(missingNodesChan chan []byte, errChan chan error, stats *snapshotStatistics, syncer AccountsDBSyncer) { +func (adb *AccountsDB) syncMissingNodes(missingNodesChan chan []byte, errChan common.BufferedErrChan, stats *snapshotStatistics, syncer AccountsDBSyncer) { defer stats.SyncFinished() if check.IfNil(syncer) { @@ -1297,7 +1298,7 @@ func (adb *AccountsDB) syncMissingNodes(missingNodesChan chan []byte, errChan ch for missingNode := range missingNodesChan { log.Warn("could not sync node", "hash", missingNode) } - errChan <- ErrNilTrieSyncer + errChan.WriteInChanNonBlocking(ErrNilTrieSyncer) return } @@ -1308,7 +1309,7 @@ func (adb *AccountsDB) syncMissingNodes(missingNodesChan chan []byte, errChan ch "missing node hash", missingNode, "error", err, ) - errChan <- err + errChan.WriteInChanNonBlocking(err) } } } @@ -1376,7 +1377,7 @@ func (adb *AccountsDB) setStateCheckpoint(rootHash []byte) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, leavesChannelSize), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } missingNodesChannel := make(chan []byte, missingNodesChannelSize) stats := newSnapshotStatistics(1, 1) @@ -1442,7 +1443,7 @@ func (adb *AccountsDB) GetStatsForRootHash(rootHash []byte) (common.TriesStatist iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, leavesChannelSize), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := mainTrie.GetAllLeavesOnChannel(iteratorChannels, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder()) if err != nil { @@ -1465,7 +1466,7 @@ func (adb *AccountsDB) GetStatsForRootHash(rootHash []byte) (common.TriesStatist collectStats(tr, stats, account.RootHash, address) } - err = common.GetErrorFromChanNonBlocking(iteratorChannels.ErrChan) + err = iteratorChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index e90d9636d18..a70f2a1fff3 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process/mock" @@ -1027,7 +1028,7 @@ func TestAccountsDB_SnapshotStateWithErrorsShouldNotMarkActiveDB(t *testing.T) { return true }, TakeSnapshotCalled: func(_ string, _ []byte, _ []byte, iteratorChannels *common.TrieIteratorChannels, _ chan []byte, stats common.SnapshotStatisticsHandler, _ uint32) { - iteratorChannels.ErrChan <- expectedErr + iteratorChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(iteratorChannels.LeavesChan) stats.SnapshotFinished() }, @@ -1428,7 +1429,7 @@ func TestAccountsDB_GetAllLeaves(t *testing.T) { GetAllLeavesOnChannelCalled: func(channels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, builder common.KeyBuilder) error { getAllLeavesCalled = true close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() return nil }, @@ -1441,13 +1442,13 @@ func TestAccountsDB_GetAllLeaves(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := adb.GetAllLeaves(leavesChannel, context.Background(), []byte("root hash")) assert.Nil(t, err) assert.True(t, getAllLeavesCalled) - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) } @@ -2324,10 +2325,10 @@ func TestAccountsDB_RecreateAllTries(t *testing.T) { GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder) error { go func() { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("key"), []byte("val")) - leavesChannels.ErrChan <- expectedErr + leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -2355,7 +2356,7 @@ func TestAccountsDB_RecreateAllTries(t *testing.T) { leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("key"), []byte("val")) close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil @@ -2733,17 +2734,17 @@ func TestEmptyErrChanReturningHadContained(t *testing.T) { t.Run("unbuffered chan", func(t *testing.T) { t.Parallel() - errChan := make(chan error) - assert.False(t, state.EmptyErrChanReturningHadContained(errChan)) - assert.Equal(t, 0, len(errChan)) + errChannel := make(chan error) + assert.False(t, state.EmptyErrChanReturningHadContained(errChannel)) + assert.Equal(t, 0, len(errChannel)) }) t.Run("buffered chan", func(t *testing.T) { t.Parallel() for i := 1; i < 10; i++ { - errChan := make(chan error, i) - assert.False(t, state.EmptyErrChanReturningHadContained(errChan)) - assert.Equal(t, 0, len(errChan)) + errChannel := make(chan error, i) + assert.False(t, state.EmptyErrChanReturningHadContained(errChannel)) + assert.Equal(t, 0, len(errChannel)) } }) }) @@ -2753,27 +2754,27 @@ func TestEmptyErrChanReturningHadContained(t *testing.T) { t.Run("unbuffered chan", func(t *testing.T) { t.Parallel() - errChan := make(chan error) + errChannel := make(chan error) go func() { - errChan <- errors.New("test") + errChannel <- errors.New("test") }() time.Sleep(time.Second) // allow the go routine to start - assert.True(t, state.EmptyErrChanReturningHadContained(errChan)) - assert.Equal(t, 0, len(errChan)) + assert.True(t, state.EmptyErrChanReturningHadContained(errChannel)) + assert.Equal(t, 0, len(errChannel)) }) t.Run("buffered chan", func(t *testing.T) { t.Parallel() for i := 1; i < 10; i++ { - errChan := make(chan error, i) + errChannel := make(chan error, i) for j := 0; j < i; j++ { - errChan <- errors.New("test") + errChannel <- errors.New("test") } - assert.True(t, state.EmptyErrChanReturningHadContained(errChan)) - assert.Equal(t, 0, len(errChan)) + assert.True(t, state.EmptyErrChanReturningHadContained(errChannel)) + assert.Equal(t, 0, len(errChannel)) } }) }) @@ -2900,6 +2901,39 @@ func TestAccountsDB_SyncMissingSnapshotNodes(t *testing.T) { assert.True(t, isMissingNodeCalled) }) + + t.Run("should not deadlock if sync err after another err", func(t *testing.T) { + t.Parallel() + + missingNodeError := errors.New("missing trie node") + isMissingNodeCalled := false + + memDbMock := testscommon.NewMemDbMock() + memDbMock.PutCalled = func(key, val []byte) error { + return fmt.Errorf("put error") + } + memDbMock.GetCalled = func(key []byte) ([]byte, error) { + if bytes.Equal(key, []byte(common.ActiveDBKey)) { + return []byte(common.ActiveDBVal), nil + } + + isMissingNodeCalled = true + return nil, missingNodeError + } + + tr, adb := getDefaultTrieAndAccountsDbWithCustomDB(&testscommon.SnapshotPruningStorerMock{MemDbMock: memDbMock}) + prepareTrie(tr, 3) + + rootHash, _ := tr.RootHash() + + adb.SnapshotState(rootHash) + + for tr.GetStorageManager().IsPruningBlocked() { + time.Sleep(time.Millisecond * 100) + } + + assert.True(t, isMissingNodeCalled) + }) } func prepareTrie(tr common.Trie, numKeys int) { diff --git a/state/peerAccountsDB.go b/state/peerAccountsDB.go index ed1f080069e..171ab6e3d06 100644 --- a/state/peerAccountsDB.go +++ b/state/peerAccountsDB.go @@ -2,6 +2,7 @@ package state import ( "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" ) // PeerAccountsDB will save and synchronize data from peer processor, plus will synchronize with nodesCoordinator @@ -56,7 +57,7 @@ func (adb *PeerAccountsDB) SnapshotState(rootHash []byte) { missingNodesChannel := make(chan []byte, missingNodesChannelSize) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } stats := newSnapshotStatistics(0, 1) stats.NewSnapshotStarted() @@ -92,7 +93,7 @@ func (adb *PeerAccountsDB) SetStateCheckpoint(rootHash []byte) { stats.NewSnapshotStarted() iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } trieStorageManager.SetCheckpoint(rootHash, rootHash, iteratorChannels, missingNodesChannel, stats) diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go index ca2e1142266..2e4f7f1f156 100644 --- a/state/syncer/userAccountsSyncer.go +++ b/state/syncer/userAccountsSyncer.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" @@ -213,7 +214,7 @@ func (u *userAccountsSyncer) syncAccountDataTries( leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = mainTrie.GetAllLeavesOnChannel(leavesChannels, context.Background(), mainRootHash, keyBuilder.NewDisabledKeyBuilder()) if err != nil { @@ -265,7 +266,7 @@ func (u *userAccountsSyncer) syncAccountDataTries( wg.Wait() - err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err = leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } diff --git a/trie/export_test.go b/trie/export_test.go index 66c17dd56c3..83fa38f3c8f 100644 --- a/trie/export_test.go +++ b/trie/export_test.go @@ -75,7 +75,10 @@ func GetDirtyHashes(tr common.Trie) common.ModifiedHashes { // WriteInChanNonBlocking - func WriteInChanNonBlocking(errChan chan error, err error) { - writeInChanNonBlocking(errChan, err) + select { + case errChan <- err: + default: + } } type StorageManagerExtensionStub struct { diff --git a/trie/node_test.go b/trie/node_test.go index a28a2c92c2c..f6bfcf165ce 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" dataMock "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/trie/keyBuilder" @@ -522,7 +523,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesCollapsedTrie(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), tr.root.getHash(), keyBuilder.NewKeyBuilder()) assert.Nil(t, err) @@ -532,7 +533,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesCollapsedTrie(t *testing.T) { leaves[string(l.Key())] = l.Value() } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) assert.Equal(t, 3, len(leaves)) diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 59b6d988f59..9a742683858 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -459,14 +459,14 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( if err != nil { tr.mutOperation.RUnlock() close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() return err } if check.IfNil(newTrie) || newTrie.root == nil { tr.mutOperation.RUnlock() close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() return nil } @@ -483,7 +483,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( ctx, ) if err != nil { - writeInChanNonBlocking(leavesChannels.ErrChan, err) + leavesChannels.ErrChan.WriteInChanNonBlocking(err) log.Error("could not get all trie leaves: ", "error", err) } @@ -492,7 +492,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( tr.mutOperation.Unlock() close(leavesChannels.LeavesChan) - close(leavesChannels.ErrChan) + leavesChannels.ErrChan.Close() }() return nil diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 45b9066e490..9ec956e96cf 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/keccak" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/testscommon" @@ -475,7 +476,7 @@ func TestPatriciaMerkleTrie_GetSerializedNodesGetFromCheckpoint(t *testing.T) { storageManager.AddDirtyCheckpointHashes(rootHash, dirtyHashes) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } storageManager.SetCheckpoint(rootHash, make([]byte, 0), iteratorChannels, nil, &trieMock.MockStatistics{}) trie.WaitForOperationToComplete(storageManager) @@ -562,7 +563,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := tr.GetAllLeavesOnChannel(iteratorChannels, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) assert.Equal(t, trie.ErrNilTrieIteratorLeavesChannel, err) @@ -588,7 +589,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), []byte{}, keyBuilder.NewDisabledKeyBuilder()) assert.Nil(t, err) @@ -597,7 +598,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { _, ok := <-leavesChannel.LeavesChan assert.False(t, ok) - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) }) @@ -610,7 +611,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } expectedErr := errors.New("expected error") @@ -630,7 +631,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { for leaf := range leavesChannel.LeavesChan { recovered[string(leaf.Key())] = leaf.Value() } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Equal(t, expectedErr, err) assert.Equal(t, 0, len(recovered)) }) @@ -646,7 +647,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } expectedErr := errors.New("expected error") @@ -672,7 +673,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { for leaf := range leavesChannel.LeavesChan { recovered[string(leaf.Key())] = leaf.Value() } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Equal(t, expectedErr, err) expectedLeaves := map[string][]byte{ @@ -695,7 +696,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { leavesChannel := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) assert.Nil(t, err) @@ -705,7 +706,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) { for leaf := range leavesChannel.LeavesChan { recovered[string(leaf.Key())] = leaf.Value() } - err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan) + err = leavesChannel.ErrChan.ReadFromChanNonBlocking() assert.Nil(t, err) assert.Equal(t, leaves, recovered) }) diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go index dc50faff711..c5304e45428 100644 --- a/trie/trieStorageManager.go +++ b/trie/trieStorageManager.go @@ -438,7 +438,7 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, stsm, err := newSnapshotTrieStorageManager(tsm, snapshotEntry.epoch) if err != nil { - writeInChanNonBlocking(snapshotEntry.iteratorChannels.ErrChan, err) + snapshotEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) log.Error("takeSnapshot: trie storage manager: newSnapshotTrieStorageManager", "rootHash", snapshotEntry.rootHash, "main trie rootHash", snapshotEntry.mainTrieRootHash, @@ -448,7 +448,7 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, newRoot, err := newSnapshotNode(stsm, msh, hsh, snapshotEntry.rootHash, snapshotEntry.missingNodesChan) if err != nil { - writeInChanNonBlocking(snapshotEntry.iteratorChannels.ErrChan, err) + snapshotEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) treatSnapshotError(err, "trie storage manager: newSnapshotNode takeSnapshot", snapshotEntry.rootHash, @@ -460,7 +460,7 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, stats := statistics.NewTrieStatistics() err = newRoot.commitSnapshot(stsm, snapshotEntry.iteratorChannels.LeavesChan, snapshotEntry.missingNodesChan, ctx, stats, tsm.idleProvider, rootDepthLevel) if err != nil { - writeInChanNonBlocking(snapshotEntry.iteratorChannels.ErrChan, err) + snapshotEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) treatSnapshotError(err, "trie storage manager: takeSnapshot commit", snapshotEntry.rootHash, @@ -473,13 +473,6 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, snapshotEntry.stats.AddTrieStats(stats.GetTrieStats()) } -func writeInChanNonBlocking(errChan chan error, err error) { - select { - case errChan <- err: - default: - } -} - func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEntry, msh marshal.Marshalizer, hsh hashing.Hasher, ctx context.Context, goRoutinesThrottler core.Throttler) { defer func() { tsm.finishOperation(checkpointEntry, "trie checkpoint finished") @@ -490,7 +483,7 @@ func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEnt newRoot, err := newSnapshotNode(tsm, msh, hsh, checkpointEntry.rootHash, checkpointEntry.missingNodesChan) if err != nil { - writeInChanNonBlocking(checkpointEntry.iteratorChannels.ErrChan, err) + checkpointEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) treatSnapshotError(err, "trie storage manager: newSnapshotNode takeCheckpoint", checkpointEntry.rootHash, @@ -502,7 +495,7 @@ func (tsm *trieStorageManager) takeCheckpoint(checkpointEntry *snapshotsQueueEnt stats := statistics.NewTrieStatistics() err = newRoot.commitCheckpoint(tsm, tsm.checkpointsStorer, tsm.checkpointHashesHolder, checkpointEntry.iteratorChannels.LeavesChan, ctx, stats, tsm.idleProvider, rootDepthLevel) if err != nil { - writeInChanNonBlocking(checkpointEntry.iteratorChannels.ErrChan, err) + checkpointEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) treatSnapshotError(err, "trie storage manager: takeCheckpoint commit", checkpointEntry.rootHash, diff --git a/trie/trieStorageManagerFactory_test.go b/trie/trieStorageManagerFactory_test.go index d5a28801d9c..8045a06d707 100644 --- a/trie/trieStorageManagerFactory_test.go +++ b/trie/trieStorageManagerFactory_test.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/testscommon" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" @@ -134,7 +135,7 @@ func TestTrieStorageManager_SerialFuncShadowingCallsExpectedImpl(t *testing.T) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } tsm.SetCheckpoint(nil, nil, iteratorChannels, nil, &trieMock.MockStatistics{}) @@ -167,7 +168,7 @@ func testTsmWithoutSnapshot( iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } tsm.TakeSnapshot("", nil, nil, iteratorChannels, nil, &trieMock.MockStatistics{}, 10) diff --git a/trie/trieStorageManagerWithoutCheckpoints_test.go b/trie/trieStorageManagerWithoutCheckpoints_test.go index 0f3cf254a77..891a14a392e 100644 --- a/trie/trieStorageManagerWithoutCheckpoints_test.go +++ b/trie/trieStorageManagerWithoutCheckpoints_test.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/stretchr/testify/assert" @@ -27,14 +28,14 @@ func TestTrieStorageManagerWithoutCheckpoints_SetCheckpoint(t *testing.T) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), iteratorChannels, nil, &trieMock.MockStatistics{}) assert.Equal(t, uint32(0), ts.PruningBlockingOperations()) iteratorChannels = &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.SetCheckpoint([]byte("rootHash"), make([]byte, 0), iteratorChannels, nil, &trieMock.MockStatistics{}) assert.Equal(t, uint32(0), ts.PruningBlockingOperations()) diff --git a/trie/trieStorageManagerWithoutSnapshot_test.go b/trie/trieStorageManagerWithoutSnapshot_test.go index 4077c71978a..309e328433f 100644 --- a/trie/trieStorageManagerWithoutSnapshot_test.go +++ b/trie/trieStorageManagerWithoutSnapshot_test.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/stretchr/testify/assert" @@ -79,7 +80,7 @@ func TestTrieStorageManagerWithoutSnapshot_TakeSnapshot(t *testing.T) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.TakeSnapshot("", nil, nil, iteratorChannels, nil, &trieMock.MockStatistics{}, 10) diff --git a/trie/trieStorageManager_test.go b/trie/trieStorageManager_test.go index f634024514d..a0b5a88ce63 100644 --- a/trie/trieStorageManager_test.go +++ b/trie/trieStorageManager_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/testscommon" @@ -36,6 +37,12 @@ func getNewTrieStorageManagerArgs() trie.NewTrieStorageManagerArgs { } } +// errChanWithLen extends the BufferedErrChan interface with a Len method +type errChanWithLen interface { + common.BufferedErrChan + Len() int +} + func TestNewTrieStorageManager(t *testing.T) { t.Parallel() @@ -91,7 +98,7 @@ func TestTrieCheckpoint(t *testing.T) { trieStorage.AddDirtyCheckpointHashes(rootHash, dirtyHashes) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } trieStorage.SetCheckpoint(rootHash, []byte{}, iteratorChannels, nil, &trieMock.MockStatistics{}) trie.WaitForOperationToComplete(trieStorage) @@ -99,7 +106,10 @@ func TestTrieCheckpoint(t *testing.T) { val, err = trieStorage.GetFromCheckpoint(rootHash) assert.Nil(t, err) assert.NotNil(t, val) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieStorageManager_SetCheckpointNilErrorChan(t *testing.T) { @@ -131,13 +141,15 @@ func TestTrieStorageManager_SetCheckpointClosedDb(t *testing.T) { rootHash := []byte("rootHash") iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.SetCheckpoint(rootHash, rootHash, iteratorChannels, nil, &trieMock.MockStatistics{}) _, ok := <-iteratorChannels.LeavesChan assert.False(t, ok) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieStorageManager_SetCheckpointEmptyTrieRootHash(t *testing.T) { @@ -149,13 +161,15 @@ func TestTrieStorageManager_SetCheckpointEmptyTrieRootHash(t *testing.T) { rootHash := make([]byte, 32) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.SetCheckpoint(rootHash, rootHash, iteratorChannels, nil, &trieMock.MockStatistics{}) _, ok := <-iteratorChannels.LeavesChan assert.False(t, ok) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieCheckpoint_DoesNotSaveToCheckpointStorageIfNotDirty(t *testing.T) { @@ -170,7 +184,7 @@ func TestTrieCheckpoint_DoesNotSaveToCheckpointStorageIfNotDirty(t *testing.T) { iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: nil, - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } trieStorage.SetCheckpoint(rootHash, []byte{}, iteratorChannels, nil, &trieMock.MockStatistics{}) trie.WaitForOperationToComplete(trieStorage) @@ -178,7 +192,9 @@ func TestTrieCheckpoint_DoesNotSaveToCheckpointStorageIfNotDirty(t *testing.T) { val, err = trieStorage.GetFromCheckpoint(rootHash) assert.NotNil(t, err) assert.Nil(t, val) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieStorageManager_IsPruningEnabled(t *testing.T) { @@ -344,13 +360,15 @@ func TestTrieStorageManager_TakeSnapshotClosedDb(t *testing.T) { rootHash := []byte("rootHash") iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.TakeSnapshot("", rootHash, rootHash, iteratorChannels, nil, &trieMock.MockStatistics{}, 0) _, ok := <-iteratorChannels.LeavesChan assert.False(t, ok) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieStorageManager_TakeSnapshotEmptyTrieRootHash(t *testing.T) { @@ -362,13 +380,15 @@ func TestTrieStorageManager_TakeSnapshotEmptyTrieRootHash(t *testing.T) { rootHash := make([]byte, 32) iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } ts.TakeSnapshot("", rootHash, rootHash, iteratorChannels, nil, &trieMock.MockStatistics{}, 0) _, ok := <-iteratorChannels.LeavesChan assert.False(t, ok) - assert.Equal(t, 0, len(iteratorChannels.ErrChan)) + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 0, ch.Len()) } func TestTrieStorageManager_TakeSnapshotWithGetNodeFromDBError(t *testing.T) { @@ -381,15 +401,17 @@ func TestTrieStorageManager_TakeSnapshotWithGetNodeFromDBError(t *testing.T) { rootHash := []byte("rootHash") iteratorChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } missingNodesChan := make(chan []byte, 2) ts.TakeSnapshot("", rootHash, rootHash, iteratorChannels, missingNodesChan, &trieMock.MockStatistics{}, 0) _, ok := <-iteratorChannels.LeavesChan assert.False(t, ok) - require.Equal(t, 1, len(iteratorChannels.ErrChan)) - errRecovered := <-iteratorChannels.ErrChan + ch, ok := iteratorChannels.ErrChan.(errChanWithLen) + assert.True(t, ok) + assert.Equal(t, 1, ch.Len()) + errRecovered := iteratorChannels.ErrChan.ReadFromChanNonBlocking() assert.True(t, strings.Contains(errRecovered.Error(), common.GetNodeFromDBErrorString)) } @@ -430,20 +452,20 @@ func TestWriteInChanNonBlocking(t *testing.T) { t.Run("unbuffered, reader has been set up, should add", func(t *testing.T) { t.Parallel() - errChan := make(chan error) + errChannel := make(chan error) var recovered error wg := sync.WaitGroup{} wg.Add(1) // set up the consumer that will be blocked until writing is done go func() { - recovered = <-errChan + recovered = <-errChannel wg.Done() }() time.Sleep(time.Second) // allow the go routine to start - trie.WriteInChanNonBlocking(errChan, err1) + trie.WriteInChanNonBlocking(errChannel, err1) wg.Wait() assert.Equal(t, err1, recovered) @@ -453,8 +475,8 @@ func TestWriteInChanNonBlocking(t *testing.T) { chanFinish := make(chan struct{}) go func() { - errChan := make(chan error) - trie.WriteInChanNonBlocking(errChan, err1) + errChannel := make(chan error) + trie.WriteInChanNonBlocking(errChannel, err1) close(chanFinish) }() @@ -468,53 +490,54 @@ func TestWriteInChanNonBlocking(t *testing.T) { t.Run("buffered (one element), empty chan should add", func(t *testing.T) { t.Parallel() - errChan := make(chan error, 1) - trie.WriteInChanNonBlocking(errChan, err1) - require.Equal(t, 1, len(errChan)) - recovered := <-errChan + errChannel := errChan.NewErrChanWrapper() + errChannel.WriteInChanNonBlocking(err1) + + require.Equal(t, 1, errChannel.Len()) + recovered := errChannel.ReadFromChanNonBlocking() assert.Equal(t, err1, recovered) }) t.Run("buffered (1 element), full chan should not add, but should finish", func(t *testing.T) { t.Parallel() - errChan := make(chan error, 1) - trie.WriteInChanNonBlocking(errChan, err1) - trie.WriteInChanNonBlocking(errChan, err2) + errChannel := errChan.NewErrChanWrapper() + errChannel.WriteInChanNonBlocking(err1) + errChannel.WriteInChanNonBlocking(err2) - require.Equal(t, 1, len(errChan)) - recovered := <-errChan + require.Equal(t, 1, errChannel.Len()) + recovered := errChannel.ReadFromChanNonBlocking() assert.Equal(t, err1, recovered) }) t.Run("buffered (two elements), empty chan should add", func(t *testing.T) { t.Parallel() - errChan := make(chan error, 2) - trie.WriteInChanNonBlocking(errChan, err1) - require.Equal(t, 1, len(errChan)) - recovered := <-errChan + errChannel := make(chan error, 2) + trie.WriteInChanNonBlocking(errChannel, err1) + require.Equal(t, 1, len(errChannel)) + recovered := <-errChannel assert.Equal(t, err1, recovered) - trie.WriteInChanNonBlocking(errChan, err1) - trie.WriteInChanNonBlocking(errChan, err2) - require.Equal(t, 2, len(errChan)) + trie.WriteInChanNonBlocking(errChannel, err1) + trie.WriteInChanNonBlocking(errChannel, err2) + require.Equal(t, 2, len(errChannel)) - recovered = <-errChan + recovered = <-errChannel assert.Equal(t, err1, recovered) - recovered = <-errChan + recovered = <-errChannel assert.Equal(t, err2, recovered) }) t.Run("buffered (2 elements), full chan should not add, but should finish", func(t *testing.T) { t.Parallel() - errChan := make(chan error, 2) - trie.WriteInChanNonBlocking(errChan, err1) - trie.WriteInChanNonBlocking(errChan, err2) - trie.WriteInChanNonBlocking(errChan, err3) + errChannel := make(chan error, 2) + trie.WriteInChanNonBlocking(errChannel, err1) + trie.WriteInChanNonBlocking(errChannel, err2) + trie.WriteInChanNonBlocking(errChannel, err3) - require.Equal(t, 2, len(errChan)) - recovered := <-errChan + require.Equal(t, 2, len(errChannel)) + recovered := <-errChannel assert.Equal(t, err1, recovered) - recovered = <-errChan + recovered = <-errChannel assert.Equal(t, err2, recovered) }) } diff --git a/update/genesis/common.go b/update/genesis/common.go index 8c62a78ef61..023fe6d7c8d 100644 --- a/update/genesis/common.go +++ b/update/genesis/common.go @@ -34,7 +34,7 @@ func getValidatorDataFromLeaves( validators[currentShardId] = append(validators[currentShardId], validatorInfoData) } - err := common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err := leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return nil, err } diff --git a/update/genesis/export.go b/update/genesis/export.go index f885c9cf55c..149f29ef6c1 100644 --- a/update/genesis/export.go +++ b/update/genesis/export.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" @@ -295,7 +296,7 @@ func (se *stateExport) exportTrie(key string, trie common.Trie) error { leavesChannels := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), - ErrChan: make(chan error, 1), + ErrChan: errChan.NewErrChanWrapper(), } err = trie.GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder()) if err != nil { @@ -357,7 +358,7 @@ func (se *stateExport) exportDataTries( } } - err := common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err := leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } @@ -379,7 +380,7 @@ func (se *stateExport) exportAccountLeaves( } } - err := common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan) + err := leavesChannels.ErrChan.ReadFromChanNonBlocking() if err != nil { return err } diff --git a/update/genesis/export_test.go b/update/genesis/export_test.go index 44800205606..08be4eee55c 100644 --- a/update/genesis/export_test.go +++ b/update/genesis/export_test.go @@ -294,7 +294,7 @@ func TestStateExport_ExportTrieShouldExportNodesSetupJson(t *testing.T) { go func() { channels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("test"), pacB) - channels.ErrChan <- expectedErr + channels.ErrChan.WriteInChanNonBlocking(expectedErr) close(channels.LeavesChan) }() @@ -344,7 +344,7 @@ func TestStateExport_ExportTrieShouldExportNodesSetupJson(t *testing.T) { go func() { channels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("test"), pacB) close(channels.LeavesChan) - close(channels.ErrChan) + channels.ErrChan.Close() }() return nil