Skip to content

Commit

Permalink
chore: Backport IAVL Concurrency fix for v0.19 (#829)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattverse authored Sep 4, 2023
1 parent 11558aa commit 063d8d4
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 68 deletions.
20 changes: 15 additions & 5 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package iavl
import (
"math/rand"
"sort"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -35,7 +36,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) {
})

t.Run("Unsaved Fast Iterator", func(t *testing.T) {
itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*FastNode{}, map[string]interface{}{})
itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{})
performTest(t, itr)
require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error())
})
Expand Down Expand Up @@ -296,14 +297,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
require.NoError(t, err)

// No unsaved additions or removals should be present after saving
require.Equal(t, 0, len(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Ensure that there are unsaved additions and removals present
secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree)

require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Merge the two halves
if config.ascending {
Expand All @@ -329,3 +330,12 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals)
return itr, mirror
}

func syncMapCount(m *sync.Map) int {
count := 0
m.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
77 changes: 46 additions & 31 deletions mutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ var ErrVersionDoesNotExist = errors.New("version does not exist")
//
// The inner ImmutableTree should not be used directly by callers.
type MutableTree struct {
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
orphans map[string]int64 // Nodes removed by changes to working tree.
versions map[int64]bool // The previous, saved versions of the tree.
allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion)
unsavedFastNodeAdditions map[string]*FastNode // FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
orphans map[string]int64 // Nodes removed by changes to working tree.
versions map[int64]bool // The previous, saved versions of the tree.
allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion)
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk
ndb *nodeDB
skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage

Expand All @@ -57,8 +57,8 @@ func NewMutableTreeWithOpts(db dbm.DB, cacheSize int, opts *Options, skipFastSto
orphans: map[string]int64{},
versions: map[int64]bool{},
allRootLoaded: false,
unsavedFastNodeAdditions: make(map[string]*FastNode),
unsavedFastNodeRemovals: make(map[string]interface{}),
unsavedFastNodeAdditions: &sync.Map{},
unsavedFastNodeRemovals: &sync.Map{},
ndb: ndb,
skipFastStorageUpgrade: skipFastStorageUpgrade,
}, nil
Expand Down Expand Up @@ -150,11 +150,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) {
}

if !tree.skipFastStorageUpgrade {
if fastNode, ok := tree.unsavedFastNodeAdditions[unsafeToStr(key)]; ok {
return fastNode.value, nil
if fastNode, ok := tree.unsavedFastNodeAdditions.Load(unsafeToStr(key)); ok {
return fastNode.(*FastNode).value, nil
}
// check if node was deleted
if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok {
if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok {
return nil, nil
}
}
Expand Down Expand Up @@ -811,8 +811,8 @@ func (tree *MutableTree) Rollback() {
}
tree.orphans = map[string]int64{}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = map[string]*FastNode{}
tree.unsavedFastNodeRemovals = map[string]interface{}{}
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}
}

Expand Down Expand Up @@ -931,8 +931,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) {
tree.lastSaved = tree.ImmutableTree.clone()
tree.orphans = map[string]int64{}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = make(map[string]*FastNode)
tree.unsavedFastNodeRemovals = make(map[string]interface{})
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}

hash, err := tree.Hash()
Expand All @@ -955,47 +955,62 @@ func (tree *MutableTree) saveFastNodeVersion() error {

// nolint: unused
func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*FastNode {
return tree.unsavedFastNodeAdditions
additions := make(map[string]*FastNode)
tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool {
additions[key.(string)] = value.(*FastNode)
return true
})
return additions
}

// getUnsavedFastNodeRemovals returns unsaved FastNodes to remove
// nolint: unused
func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} {
return tree.unsavedFastNodeRemovals
removals := make(map[string]interface{})
tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool {
removals[key.(string)] = value
return true
})
return removals
}

// addUnsavedAddition stores an addition into the unsaved additions map
func (tree *MutableTree) addUnsavedAddition(key []byte, node *FastNode) {
skey := unsafeToStr(key)
delete(tree.unsavedFastNodeRemovals, skey)
tree.unsavedFastNodeAdditions[skey] = node
tree.unsavedFastNodeRemovals.Delete(skey)
tree.unsavedFastNodeAdditions.Store(skey, node)
}

func (tree *MutableTree) saveFastNodeAdditions() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions))
for key := range tree.unsavedFastNodeAdditions {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil {
val, _ := tree.unsavedFastNodeAdditions.Load(key)
if err := tree.ndb.SaveFastNode(val.(*FastNode)); err != nil {
return err
}
}
return nil
}

// addUnsavedRemoval adds a removal to the unsaved removals map
func (tree *MutableTree) addUnsavedRemoval(key []byte) {
skey := unsafeToStr(key)
delete(tree.unsavedFastNodeAdditions, skey)
tree.unsavedFastNodeRemovals[skey] = true
tree.unsavedFastNodeAdditions.Delete(skey)
tree.unsavedFastNodeRemovals.Store(skey, true)
}

func (tree *MutableTree) saveFastNodeRemovals() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals))
for key := range tree.unsavedFastNodeRemovals {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
Expand Down
71 changes: 39 additions & 32 deletions unsaved_fast_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"sort"
"sync"

dbm "github.com/tendermint/tm-db"
)
Expand All @@ -28,14 +29,14 @@ type UnsavedFastIterator struct {
fastIterator dbm.Iterator

nextUnsavedNodeIdx int
unsavedFastNodeAdditions map[string]*FastNode
unsavedFastNodeRemovals map[string]interface{}
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode
unsavedFastNodeRemovals *sync.Map // map[string]interface{}
unsavedFastNodesToSort []string
}

var _ dbm.Iterator = (*UnsavedFastIterator)(nil)

func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*FastNode, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator {
func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator {
iter := &UnsavedFastIterator{
start: start,
end: end,
Expand All @@ -49,28 +50,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
fastIterator: NewFastIterator(start, end, ascending, ndb),
}

// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
for _, fastNode := range unsavedFastNodeAdditions {
if start != nil && bytes.Compare(fastNode.key, start) < 0 {
continue
}

if end != nil && bytes.Compare(fastNode.key, end) >= 0 {
continue
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, unsafeToStr(fastNode.key))
}

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

if iter.ndb == nil {
iter.err = errFastIteratorNilNdbGiven
iter.valid = false
Expand All @@ -89,7 +68,33 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
return iter
}

// Move to the first elemenet
// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort unsaved nodes, the fast node on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
fastNode := v.(*FastNode)

if start != nil && bytes.Compare(fastNode.key, start) < 0 {
return true
}

if end != nil && bytes.Compare(fastNode.key, end) >= 0 {
return true
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string))

return true
})

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

// Move to the first element
iter.Next()

return iter
Expand Down Expand Up @@ -134,16 +139,17 @@ func (iter *UnsavedFastIterator) Next() {

diskKeyStr := unsafeToStr(iter.fastIterator.Key())
if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {

if iter.unsavedFastNodeRemovals[diskKeyStr] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
return
}

nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*FastNode)

var isUnsavedNext bool
if iter.ascending {
Expand All @@ -154,7 +160,6 @@ func (iter *UnsavedFastIterator) Next() {

if isUnsavedNext {
// Unsaved node is next

if diskKeyStr == nextUnsavedKey {
// Unsaved update prevails over saved copy so we skip the copy from disk
iter.fastIterator.Next()
Expand All @@ -176,7 +181,8 @@ func (iter *UnsavedFastIterator) Next() {

// if only nodes on disk are left, we return them
if iter.fastIterator.Valid() {
if iter.unsavedFastNodeRemovals[diskKeyStr] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
Expand All @@ -193,7 +199,8 @@ func (iter *UnsavedFastIterator) Next() {
// if only unsaved nodes are left, we can just iterate
if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {
nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*FastNode)

iter.nextKey = nextUnsavedNode.key
iter.nextVal = nextUnsavedNode.value
Expand Down

0 comments on commit 063d8d4

Please sign in to comment.