Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create a new errChan struct that wraps an err chan #5068

Merged
merged 10 commits into from
Apr 3, 2023
62 changes: 62 additions & 0 deletions common/errChan/errChan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package errChan

import "sync"

type errChan struct {
ch chan error
closed bool
closeMutex sync.Mutex
}

// NewErrChan creates a new errChan
func NewErrChan() *errChan {
return &errChan{
ch: make(chan error, 1),
closed: false,
}
}

// WriteInChanNonBlocking will send the given error on the channel if the chan is not blocked
func (ec *errChan) WriteInChanNonBlocking(err error) {
select {
case ec.ch <- err:
default:
}
}

// ReadFromChanNonBlocking will read from the channel, or return nil if no error was sent on the channel
func (ec *errChan) ReadFromChanNonBlocking() error {
select {
case err := <-ec.ch:
return err
default:
return nil
}
}

// Close will close the channel
func (ec *errChan) Close() {
ec.closeMutex.Lock()
defer ec.closeMutex.Unlock()

if ec.closed {
return
}

if ec.ch == nil {
return
}

close(ec.ch)
ec.closed = true
}

// Len returns the length of the channel
func (ec *errChan) Len() int {
return len(ec.ch)
}

// IsInterfaceNil returns true if there is no value under the interface
func (ec *errChan) IsInterfaceNil() bool {
return ec == nil
}
81 changes: 81 additions & 0 deletions common/errChan/errChan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package errChan

import (
"fmt"
"testing"

"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/stretchr/testify/assert"
)

func TestNewErrChan(t *testing.T) {
t.Parallel()

ec := NewErrChan()
assert.False(t, check.IfNil(ec))
assert.Equal(t, 1, cap(ec.ch))
}

func TestErrChan_WriteInChanNonBlocking(t *testing.T) {
t.Parallel()

expectedErr := fmt.Errorf("err1")
ec := NewErrChan()
ec.WriteInChanNonBlocking(expectedErr)
ec.WriteInChanNonBlocking(fmt.Errorf("err2"))
ec.WriteInChanNonBlocking(fmt.Errorf("err3"))

assert.Equal(t, 1, len(ec.ch))
assert.Equal(t, expectedErr, <-ec.ch)
assert.Equal(t, 0, len(ec.ch))
}

func TestErrChan_ReadFromChanNonBlocking(t *testing.T) {
t.Parallel()

expectedErr := fmt.Errorf("err1")
ec := NewErrChan()
ec.ch <- expectedErr

assert.Equal(t, 1, len(ec.ch))
assert.Equal(t, expectedErr, ec.ReadFromChanNonBlocking())
assert.Equal(t, 0, len(ec.ch))
assert.Nil(t, ec.ReadFromChanNonBlocking())
}

func TestErrChan_Close(t *testing.T) {
t.Parallel()

t.Run("close an already closed channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChan()
ec.Close()

assert.True(t, ec.closed)
ec.Close()
})

t.Run("close a nil channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChan()
ec.ch = nil
ec.Close()

assert.False(t, ec.closed)
})
}

func TestErrChan_Len(t *testing.T) {
t.Parallel()

ec := NewErrChan()
assert.Equal(t, 0, ec.Len())

ec.ch <- fmt.Errorf("err1")
assert.Equal(t, 1, ec.Len())

ec.WriteInChanNonBlocking(fmt.Errorf("err2"))
assert.Equal(t, 1, ec.Len())
}
10 changes: 9 additions & 1 deletion common/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ import (
// TrieIteratorChannels defines the channels that are being used when iterating the trie nodes
type TrieIteratorChannels struct {
LeavesChan chan core.KeyValueHolder
ErrChan chan error
ErrChan BufferedErrChan
}

// BufferedErrChan is an interface that defines the methods for a buffered error channel
type BufferedErrChan interface {
WriteInChanNonBlocking(err error)
ReadFromChanNonBlocking() error
Close()
IsInterfaceNil() bool
}

// Trie is an interface for Merkle Trees implementations
Expand Down
5 changes: 3 additions & 2 deletions debug/process/stateExport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/state"
)

Expand Down Expand Up @@ -66,7 +67,7 @@ func getCode(accountsDB state.AccountsAdapter, codeHash []byte) ([]byte, error)
func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte) ([]string, error) {
leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}

err := accountsDB.GetAllLeaves(leavesChannels, context.Background(), rootHash)
Expand All @@ -89,7 +90,7 @@ func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte)
hex.EncodeToString(valWithoutSuffix)))
}

err = <-leavesChannels.ErrChan
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s",
err, hex.EncodeToString(rootHash), hex.EncodeToString(address))
Expand Down
5 changes: 3 additions & 2 deletions epochStart/metachain/systemSCs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/multiversx/mx-chain-core-go/data/block"
"github.com/multiversx/mx-chain-core-go/marshal"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
vInfo "github.com/multiversx/mx-chain-go/common/validatorInfo"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/epochStart"
Expand Down Expand Up @@ -1102,7 +1103,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc

leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder())
if err != nil {
Expand All @@ -1125,7 +1126,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc
}
}

err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan)
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions factory/processing/processComponents.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/multiversx/mx-chain-core-go/data/outport"
nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/consensus"
"github.com/multiversx/mx-chain-go/dataRetriever"
Expand Down Expand Up @@ -874,7 +875,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string

leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
err = pcf.state.AccountsAdapter().GetAllLeaves(leavesChannels, context.Background(), rootHash)
if err != nil {
Expand All @@ -901,7 +902,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string
}
}

err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan)
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, err
}
Expand Down
21 changes: 11 additions & 10 deletions integrationTests/state/stateTrieClose/stateTrieClose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/integrationTests"
"github.com/multiversx/mx-chain-go/testscommon"
Expand Down Expand Up @@ -36,25 +37,25 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) {
rootHash, _ := tr.RootHash()
leavesChannel1 := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
_ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder())
time.Sleep(time.Second) // allow the go routine to start
idx, _ := gc.Snapshot()
diff := gc.DiffGoRoutines(idxInitial, idx)
assert.True(t, len(diff) <= 1) // can be 0 on a fast running host
err := common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan)
err := leavesChannel1.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

leavesChannel1 = &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
_ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder())
idx, _ = gc.Snapshot()
diff = gc.DiffGoRoutines(idxInitial, idx)
assert.True(t, len(diff) <= 2)
err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan)
err = leavesChannel1.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

_ = tr.Update([]byte("god"), []byte("puppy"))
Expand All @@ -63,13 +64,13 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) {
rootHash, _ = tr.RootHash()
leavesChannel1 = &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
_ = tr.GetAllLeavesOnChannel(leavesChannel1, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder())
idx, _ = gc.Snapshot()
diff = gc.DiffGoRoutines(idxInitial, idx)
assert.Equal(t, 3, len(diff), fmt.Sprintf("%v", diff))
err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan)
err = leavesChannel1.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

_ = tr.Update([]byte("eggod"), []byte("cat"))
Expand All @@ -78,14 +79,14 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) {
rootHash, _ = tr.RootHash()
leavesChannel2 := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
_ = tr.GetAllLeavesOnChannel(leavesChannel2, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder())
time.Sleep(time.Second) // allow the go routine to start
idx, _ = gc.Snapshot()
diff = gc.DiffGoRoutines(idxInitial, idx)
assert.True(t, len(diff) <= 4)
err = common.GetErrorFromChanNonBlocking(leavesChannel2.ErrChan)
err = leavesChannel2.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

for range leavesChannel1.LeavesChan {
Expand All @@ -94,7 +95,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) {
idx, _ = gc.Snapshot()
diff = gc.DiffGoRoutines(idxInitial, idx)
assert.True(t, len(diff) <= 3)
err = common.GetErrorFromChanNonBlocking(leavesChannel1.ErrChan)
err = leavesChannel1.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

for range leavesChannel2.LeavesChan {
Expand All @@ -103,7 +104,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) {
idx, _ = gc.Snapshot()
diff = gc.DiffGoRoutines(idxInitial, idx)
assert.True(t, len(diff) <= 2)
err = common.GetErrorFromChanNonBlocking(leavesChannel2.ErrChan)
err = leavesChannel2.ErrChan.ReadFromChanNonBlocking()
assert.Nil(t, err)

err = tr.Close()
Expand Down
13 changes: 7 additions & 6 deletions integrationTests/state/stateTrieSync/stateTrieSync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-core-go/core/throttler"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/epochStart/notifier"
"github.com/multiversx/mx-chain-go/integrationTests"
"github.com/multiversx/mx-chain-go/process/factory"
Expand Down Expand Up @@ -329,13 +330,13 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves
rootHash, _ := accState.RootHash()
leavesChannel := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
err = accState.GetAllLeaves(leavesChannel, context.Background(), rootHash)
for range leavesChannel.LeavesChan {
}
require.Nil(t, err)
err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan)
err = leavesChannel.ErrChan.ReadFromChanNonBlocking()
require.Nil(t, err)

requesterTrie := nRequester.TrieContainer.Get([]byte(trieFactory.UserAccountTrie))
Expand All @@ -357,15 +358,15 @@ func testMultipleDataTriesSync(t *testing.T, numAccounts int, numDataTrieLeaves

leavesChannel = &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
err = nRequester.AccntState.GetAllLeaves(leavesChannel, context.Background(), rootHash)
assert.Nil(t, err)
numLeaves := 0
for range leavesChannel.LeavesChan {
numLeaves++
}
err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan)
err = leavesChannel.ErrChan.ReadFromChanNonBlocking()
require.Nil(t, err)
assert.Equal(t, numAccounts, numLeaves)
checkAllDataTriesAreSynced(t, numDataTrieLeaves, requesterTrie, dataTrieRootHashes)
Expand Down Expand Up @@ -559,7 +560,7 @@ func addAccountsToState(t *testing.T, numAccounts int, numDataTrieLeaves int, ac
func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int {
leavesChannel := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChan(),
}
err := tr.GetAllLeavesOnChannel(leavesChannel, context.Background(), rootHash, keyBuilder.NewDisabledKeyBuilder())
require.Nil(t, err)
Expand All @@ -569,7 +570,7 @@ func getNumLeaves(t *testing.T, tr common.Trie, rootHash []byte) int {
numLeaves++
}

err = common.GetErrorFromChanNonBlocking(leavesChannel.ErrChan)
err = leavesChannel.ErrChan.ReadFromChanNonBlocking()
require.Nil(t, err)

return numLeaves
Expand Down
Loading