Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support concurrency for IAVL and fix Racing conditions #805

Merged
merged 6 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package iavl
import (
"math/rand"
"sort"
"sync"
"testing"

log "cosmossdk.io/log"
dbm "github.com/cosmos/cosmos-db"
"github.com/cosmos/iavl/fastnode"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -37,7 +37,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) {
})

t.Run("Unsaved Fast Iterator", func(t *testing.T) {
itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*fastnode.Node{}, map[string]interface{}{})
itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{})
performTest(t, itr)
require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error())
})
Expand Down Expand Up @@ -292,14 +292,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
require.NoError(t, err)

// No unsaved additions or removals should be present after saving
require.Equal(t, 0, len(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Ensure that there are unsaved additions and removals present
secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree)

require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Merge the two halves
if config.ascending {
Expand Down Expand Up @@ -371,3 +371,12 @@ func TestNodeIterator_WithEmptyRoot(t *testing.T) {
require.NoError(t, err)
require.False(t, itr.Valid())
}

// syncMapCount reports how many entries m currently holds. sync.Map has no
// length method, so the entries are tallied by walking the map with Range.
func syncMapCount(m *sync.Map) int {
	var total int
	tally := func(_, _ interface{}) bool {
		total++
		// Returning true tells Range to continue to the next entry.
		return true
	}
	m.Range(tally)
	return total
}
73 changes: 45 additions & 28 deletions mutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ type Option func(*Options)
type MutableTree struct {
logger log.Logger

*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
unsavedFastNodeAdditions map[string]*fastnode.Node // FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk
ndb *nodeDB
skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage

Expand All @@ -62,8 +62,8 @@ func NewMutableTree(db dbm.DB, cacheSize int, skipFastStorageUpgrade bool, lg lo
logger: lg,
ImmutableTree: head,
lastSaved: head.clone(),
unsavedFastNodeAdditions: make(map[string]*fastnode.Node),
unsavedFastNodeRemovals: make(map[string]interface{}),
unsavedFastNodeAdditions: &sync.Map{},
unsavedFastNodeRemovals: &sync.Map{},
ndb: ndb,
skipFastStorageUpgrade: skipFastStorageUpgrade,
}
Expand Down Expand Up @@ -176,11 +176,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) {
}

if !tree.skipFastStorageUpgrade {
if fastNode, ok := tree.unsavedFastNodeAdditions[string(key)]; ok {
return fastNode.GetValue(), nil
if fastNode, ok := tree.unsavedFastNodeAdditions.Load(ibytes.UnsafeBytesToStr(key)); ok {
return fastNode.(*fastnode.Node).GetValue(), nil
}
// check if node was deleted
if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok {
if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok {
return nil, nil
}
}
Expand Down Expand Up @@ -659,8 +659,8 @@ func (tree *MutableTree) Rollback() {
}
}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = map[string]*fastnode.Node{}
tree.unsavedFastNodeRemovals = map[string]interface{}{}
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}
}

Expand Down Expand Up @@ -778,8 +778,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) {
tree.ImmutableTree = tree.ImmutableTree.clone()
tree.lastSaved = tree.ImmutableTree.clone()
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = make(map[string]*fastnode.Node)
tree.unsavedFastNodeRemovals = make(map[string]interface{})
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}

hash := tree.Hash()
Expand All @@ -797,30 +797,45 @@ func (tree *MutableTree) saveFastNodeVersion(isGenesis bool) error {
return tree.ndb.setFastStorageVersionToBatch()
}

// nolint: unused
func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*fastnode.Node {
return tree.unsavedFastNodeAdditions
additions := make(map[string]*fastnode.Node)
tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool {
additions[key.(string)] = value.(*fastnode.Node)
return true
})
return additions
}

// getUnsavedFastNodeRemovals returns unsaved FastNodes to remove

func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} {
return tree.unsavedFastNodeRemovals
removals := make(map[string]interface{})
tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool {
removals[key.(string)] = value
return true
})
return removals
}

// addUnsavedAddition stores an addition into the unsaved additions map
func (tree *MutableTree) addUnsavedAddition(key []byte, node *fastnode.Node) {
delete(tree.unsavedFastNodeRemovals, ibytes.UnsafeBytesToStr(key))
tree.unsavedFastNodeAdditions[string(key)] = node
skey := ibytes.UnsafeBytesToStr(key)
tree.unsavedFastNodeRemovals.Delete(skey)
tree.unsavedFastNodeAdditions.Store(skey, node)
}

func (tree *MutableTree) saveFastNodeAdditions(batchCommmit bool) error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions))
for key := range tree.unsavedFastNodeAdditions {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil {
val, _ := tree.unsavedFastNodeAdditions.Load(key)
if err := tree.ndb.SaveFastNode(val.(*fastnode.Node)); err != nil {
return err
}
if batchCommmit {
Expand All @@ -832,17 +847,19 @@ func (tree *MutableTree) saveFastNodeAdditions(batchCommmit bool) error {
return nil
}

// addUnsavedRemoval adds a removal to the unsaved removals map
func (tree *MutableTree) addUnsavedRemoval(key []byte) {
skey := ibytes.UnsafeBytesToStr(key)
delete(tree.unsavedFastNodeAdditions, skey)
tree.unsavedFastNodeRemovals[skey] = true
tree.unsavedFastNodeAdditions.Delete(skey)
tree.unsavedFastNodeRemovals.Store(skey, true)
}

func (tree *MutableTree) saveFastNodeRemovals() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals))
for key := range tree.unsavedFastNodeRemovals {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
Expand Down
86 changes: 47 additions & 39 deletions unsaved_fast_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"bytes"
"errors"
"sort"
"sync"

dbm "github.com/cosmos/cosmos-db"

"github.com/cosmos/iavl/fastnode"
ibytes "github.com/cosmos/iavl/internal/bytes"
)

var (
Expand All @@ -29,14 +32,14 @@ type UnsavedFastIterator struct {
fastIterator dbm.Iterator

nextUnsavedNodeIdx int
unsavedFastNodeAdditions map[string]*fastnode.Node
unsavedFastNodeRemovals map[string]interface{}
unsavedFastNodesToSort [][]byte
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode
unsavedFastNodeRemovals *sync.Map // map[string]interface{}
unsavedFastNodesToSort []string
}

var _ dbm.Iterator = (*UnsavedFastIterator)(nil)

func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*fastnode.Node, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator {
func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator {
iter := &UnsavedFastIterator{
start: start,
end: end,
Expand All @@ -50,29 +53,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
fastIterator: NewFastIterator(start, end, ascending, ndb),
}

// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
for _, fastNode := range unsavedFastNodeAdditions {
if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
continue
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
continue
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, fastNode.GetKey())
}

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
cmp := bytes.Compare(iter.unsavedFastNodesToSort[i], iter.unsavedFastNodesToSort[j])
if ascending {
return cmp < 0
}
return cmp > 0
})

if iter.ndb == nil {
iter.err = errFastIteratorNilNdbGiven
iter.valid = false
Expand All @@ -90,8 +70,34 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
iter.valid = false
return iter
}
// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
fastNode := v.(*fastnode.Node)

if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
return true
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
return true
}

// convert key to bytes. Type conversion failure should not happen in practice
iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string))

// Move to the first elemenet
return true
})

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

// Move to the first element
iter.Next()

return iter
Expand Down Expand Up @@ -134,31 +140,31 @@ func (iter *UnsavedFastIterator) Next() {
return
}

diskKeyStr := iter.fastIterator.Key()
diskKey := iter.fastIterator.Key()
diskKeyStr := ibytes.UnsafeBytesToStr(diskKey)
if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {

if iter.unsavedFastNodeRemovals[string(diskKeyStr)] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
return
}

nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[string(nextUnsavedKey)]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably check that the second return value of `Load` (`ok`) is true before proceeding; otherwise the type assertion on the next line panics on a nil value.

nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

var isUnsavedNext bool
cmp := bytes.Compare(diskKeyStr, nextUnsavedKey)
if iter.ascending {
isUnsavedNext = cmp >= 0
isUnsavedNext = diskKeyStr >= nextUnsavedKey
} else {
isUnsavedNext = cmp <= 0
isUnsavedNext = diskKeyStr <= nextUnsavedKey
}

if isUnsavedNext {
// Unsaved node is next

if cmp == 0 {
if diskKeyStr == nextUnsavedKey {
// Unsaved update prevails over saved copy so we skip the copy from disk
iter.fastIterator.Next()
}
Expand All @@ -179,7 +185,8 @@ func (iter *UnsavedFastIterator) Next() {

// if only nodes on disk are left, we return them
if iter.fastIterator.Valid() {
if iter.unsavedFastNodeRemovals[string(diskKeyStr)] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
Expand All @@ -196,7 +203,8 @@ func (iter *UnsavedFastIterator) Next() {
// if only unsaved nodes are left, we can just iterate
if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {
nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[string(nextUnsavedKey)]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worthwhile checking the second return value of `Load` (`ok`) here as well.

nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

iter.nextKey = nextUnsavedNode.GetKey()
iter.nextVal = nextUnsavedNode.GetValue()
Expand Down