Skip to content

Commit

Permalink
Merge 3c0f4df into b94c62d
Browse files Browse the repository at this point in the history
  • Loading branch information
BeniaminDrasovean authored Apr 3, 2023
2 parents b94c62d + 3c0f4df commit 6a1b1c8
Show file tree
Hide file tree
Showing 41 changed files with 488 additions and 266 deletions.
10 changes: 0 additions & 10 deletions common/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,3 @@ func GetClosedUnbufferedChannel() chan struct{} {

return ch
}

// GetErrorFromChanNonBlocking will get the error from channel
func GetErrorFromChanNonBlocking(errChan chan error) error {
select {
case err := <-errChan:
return err
default:
return nil
}
}
57 changes: 0 additions & 57 deletions common/channels_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package common

import (
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -39,57 +36,3 @@ func didTriggerHappen(ch chan struct{}) bool {
return false
}
}

func TestErrFromChannel(t *testing.T) {
t.Parallel()

t.Run("empty channel, should return nil", func(t *testing.T) {
t.Parallel()

t.Run("unbuffered chan", func(t *testing.T) {
t.Parallel()

errChan := make(chan error)
assert.Nil(t, GetErrorFromChanNonBlocking(errChan))
})

t.Run("buffered chan", func(t *testing.T) {
t.Parallel()

errChan := make(chan error, 1)
assert.Nil(t, GetErrorFromChanNonBlocking(errChan))
})
})

t.Run("non empty channel, should return error", func(t *testing.T) {
t.Parallel()

t.Run("unbuffered chan", func(t *testing.T) {
t.Parallel()

expectedErr := errors.New("expected error")
errChan := make(chan error)
go func() {
errChan <- expectedErr
}()

time.Sleep(time.Second) // allow the go routine to start

assert.Equal(t, expectedErr, GetErrorFromChanNonBlocking(errChan))
})

t.Run("buffered chan", func(t *testing.T) {
t.Parallel()

for i := 1; i < 10; i++ {
errChan := make(chan error, i)
expectedErr := errors.New("expected error")
for j := 0; j < i; j++ {
errChan <- expectedErr
}

assert.Equal(t, expectedErr, GetErrorFromChanNonBlocking(errChan))
}
})
})
}
69 changes: 69 additions & 0 deletions common/errChan/errChan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package errChan

import "sync"

type errChanWrapper struct {
ch chan error
closed bool
closeMutex sync.RWMutex
}

// NewErrChanWrapper creates a new errChanWrapper
func NewErrChanWrapper() *errChanWrapper {
return &errChanWrapper{
ch: make(chan error, 1),
closed: false,
}
}

// WriteInChanNonBlocking will send the given error on the channel if the chan is not blocked
func (ec *errChanWrapper) WriteInChanNonBlocking(err error) {
ec.closeMutex.RLock()
defer ec.closeMutex.RUnlock()

if ec.closed {
return
}

select {
case ec.ch <- err:
default:
}
}

// ReadFromChanNonBlocking will read from the channel, or return nil if no error was sent on the channel
func (ec *errChanWrapper) ReadFromChanNonBlocking() error {
select {
case err := <-ec.ch:
return err
default:
return nil
}
}

// Close will close the channel
func (ec *errChanWrapper) Close() {
ec.closeMutex.Lock()
defer ec.closeMutex.Unlock()

if ec.closed {
return
}

if ec.ch == nil {
return
}

close(ec.ch)
ec.closed = true
}

// Len returns the length of the channel
func (ec *errChanWrapper) Len() int {
return len(ec.ch)
}

// IsInterfaceNil returns true if there is no value under the interface
func (ec *errChanWrapper) IsInterfaceNil() bool {
return ec == nil
}
136 changes: 136 additions & 0 deletions common/errChan/errChan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package errChan

import (
"fmt"
"sync"
"testing"

"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/stretchr/testify/assert"
)

func TestNewErrChan(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
assert.False(t, check.IfNil(ec))
assert.Equal(t, 1, cap(ec.ch))
}

func TestErrChan_WriteInChanNonBlocking(t *testing.T) {
t.Parallel()

t.Run("write in a nil channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
ec.ch = nil
ec.WriteInChanNonBlocking(fmt.Errorf("err1"))

assert.Equal(t, 0, len(ec.ch))
})

t.Run("write in a closed channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
ec.Close()
ec.WriteInChanNonBlocking(fmt.Errorf("err1"))

assert.Equal(t, 0, len(ec.ch))
})

t.Run("should work", func(t *testing.T) {
expectedErr := fmt.Errorf("err1")
ec := NewErrChanWrapper()
ec.WriteInChanNonBlocking(expectedErr)
ec.WriteInChanNonBlocking(fmt.Errorf("err2"))
ec.WriteInChanNonBlocking(fmt.Errorf("err3"))

assert.Equal(t, 1, len(ec.ch))
assert.Equal(t, expectedErr, <-ec.ch)
assert.Equal(t, 0, len(ec.ch))
})
}

func TestErrChan_ReadFromChanNonBlocking(t *testing.T) {
t.Parallel()

expectedErr := fmt.Errorf("err1")
ec := NewErrChanWrapper()
ec.ch <- expectedErr

assert.Equal(t, 1, len(ec.ch))
assert.Equal(t, expectedErr, ec.ReadFromChanNonBlocking())
assert.Equal(t, 0, len(ec.ch))
assert.Nil(t, ec.ReadFromChanNonBlocking())
}

func TestErrChan_Close(t *testing.T) {
t.Parallel()

t.Run("close an already closed channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
ec.Close()

assert.True(t, ec.closed)
ec.Close()
})

t.Run("close a nil channel", func(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
ec.ch = nil
ec.Close()

assert.False(t, ec.closed)
})
}

func TestErrChan_Len(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
assert.Equal(t, 0, ec.Len())

ec.ch <- fmt.Errorf("err1")
assert.Equal(t, 1, ec.Len())

ec.WriteInChanNonBlocking(fmt.Errorf("err2"))
assert.Equal(t, 1, ec.Len())
}

func TestErrChan_ConcurrentOperations(t *testing.T) {
t.Parallel()

ec := NewErrChanWrapper()
numOperations := 1000
numMethods := 2
wg := sync.WaitGroup{}
wg.Add(numOperations)
for i := 0; i < numOperations; i++ {
go func(idx int) {

if idx == numOperations-100 {
ec.Close()
}

operation := idx % numMethods
switch operation {
case 0:
ec.WriteInChanNonBlocking(fmt.Errorf("err"))
case 1:
_ = ec.ReadFromChanNonBlocking()
default:
assert.Fail(t, "invalid numMethods")
}

wg.Done()
}(i)
}

wg.Wait()
}
10 changes: 9 additions & 1 deletion common/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ import (
// TrieIteratorChannels defines the channels that are being used when iterating the trie nodes
type TrieIteratorChannels struct {
LeavesChan chan core.KeyValueHolder
ErrChan chan error
ErrChan BufferedErrChan
}

// BufferedErrChan is an interface that defines the methods for a buffered error channel
type BufferedErrChan interface {
WriteInChanNonBlocking(err error)
ReadFromChanNonBlocking() error
Close()
IsInterfaceNil() bool
}

// Trie is an interface for Merkle Trees implementations
Expand Down
5 changes: 3 additions & 2 deletions debug/process/stateExport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/state"
)

Expand Down Expand Up @@ -66,7 +67,7 @@ func getCode(accountsDB state.AccountsAdapter, codeHash []byte) ([]byte, error)
func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte) ([]string, error) {
leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChanWrapper(),
}

err := accountsDB.GetAllLeaves(leavesChannels, context.Background(), rootHash)
Expand All @@ -89,7 +90,7 @@ func getData(accountsDB state.AccountsAdapter, rootHash []byte, address []byte)
hex.EncodeToString(valWithoutSuffix)))
}

err = <-leavesChannels.ErrChan
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, fmt.Errorf("%w while trying to export data on hex root hash %s, address %s",
err, hex.EncodeToString(rootHash), hex.EncodeToString(address))
Expand Down
5 changes: 3 additions & 2 deletions epochStart/metachain/systemSCs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/multiversx/mx-chain-core-go/data/block"
"github.com/multiversx/mx-chain-core-go/marshal"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
vInfo "github.com/multiversx/mx-chain-go/common/validatorInfo"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/epochStart"
Expand Down Expand Up @@ -1102,7 +1103,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc

leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChanWrapper(),
}
err = userValidatorAccount.DataTrie().GetAllLeavesOnChannel(leavesChannels, context.Background(), rootHash, keyBuilder.NewKeyBuilder())
if err != nil {
Expand All @@ -1125,7 +1126,7 @@ func (s *systemSCProcessor) getArgumentsForSetOwnerFunctionality(userValidatorAc
}
}

err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan)
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions factory/processing/processComponents.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/multiversx/mx-chain-core-go/data/outport"
nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/errChan"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/consensus"
"github.com/multiversx/mx-chain-go/dataRetriever"
Expand Down Expand Up @@ -874,7 +875,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string

leavesChannels := &common.TrieIteratorChannels{
LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity),
ErrChan: make(chan error, 1),
ErrChan: errChan.NewErrChanWrapper(),
}
err = pcf.state.AccountsAdapter().GetAllLeaves(leavesChannels, context.Background(), rootHash)
if err != nil {
Expand All @@ -901,7 +902,7 @@ func (pcf *processComponentsFactory) indexAndReturnGenesisAccounts() (map[string
}
}

err = common.GetErrorFromChanNonBlocking(leavesChannels.ErrChan)
err = leavesChannels.ErrChan.ReadFromChanNonBlocking()
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 6a1b1c8

Please sign in to comment.