From 063d8d4bf6ba6b7771b3a76e5d7603761adfd396 Mon Sep 17 00:00:00 2001 From: "Matt, Park" <45252226+mattverse@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:20:13 +0900 Subject: [PATCH] chore: Backport IAVL Concurrency fix for v0.19 (#829) --- iterator_test.go | 20 ++++++++--- mutable_tree.go | 77 ++++++++++++++++++++++++---------------- unsaved_fast_iterator.go | 71 +++++++++++++++++++----------------- 3 files changed, 100 insertions(+), 68 deletions(-) diff --git a/iterator_test.go b/iterator_test.go index 4164e055a..ae516261e 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -3,6 +3,7 @@ package iavl import ( "math/rand" "sort" + "sync" "testing" "github.com/stretchr/testify/require" @@ -35,7 +36,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) { }) t.Run("Unsaved Fast Iterator", func(t *testing.T) { - itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*FastNode{}, map[string]interface{}{}) + itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{}) performTest(t, itr) require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) }) @@ -296,14 +297,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite require.NoError(t, err) // No unsaved additions or removals should be present after saving - require.Equal(t, 0, len(tree.unsavedFastNodeAdditions)) - require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals)) // Ensure that there are unsaved additions and removals present secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree) - require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) - require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals)) // Merge the two halves if config.ascending { @@ -329,3 +330,12 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) return itr, mirror } + +func syncMapCount(m *sync.Map) int { + count := 0 + m.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} diff --git a/mutable_tree.go b/mutable_tree.go index f5856125b..c57ddf190 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -28,13 +28,13 @@ var ErrVersionDoesNotExist = errors.New("version does not exist") // // The inner ImmutableTree should not be used directly by callers. type MutableTree struct { - *ImmutableTree // The current, working tree. - lastSaved *ImmutableTree // The most recently saved tree. - orphans map[string]int64 // Nodes removed by changes to working tree. - versions map[int64]bool // The previous, saved versions of the tree. - allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) - unsavedFastNodeAdditions map[string]*FastNode // FastNodes that have not yet been saved to disk - unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk + *ImmutableTree // The current, working tree. + lastSaved *ImmutableTree // The most recently saved tree. + orphans map[string]int64 // Nodes removed by changes to working tree. + versions map[int64]bool // The previous, saved versions of the tree. + allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) + unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk + unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk ndb *nodeDB skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage @@ -57,8 +57,8 @@ func NewMutableTreeWithOpts(db dbm.DB, cacheSize int, opts *Options, skipFastSto orphans: map[string]int64{}, versions: map[int64]bool{}, allRootLoaded: false, - unsavedFastNodeAdditions: make(map[string]*FastNode), - unsavedFastNodeRemovals: make(map[string]interface{}), + unsavedFastNodeAdditions: &sync.Map{}, + unsavedFastNodeRemovals: &sync.Map{}, ndb: ndb, skipFastStorageUpgrade: skipFastStorageUpgrade, }, nil @@ -150,11 +150,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) { } if !tree.skipFastStorageUpgrade { - if fastNode, ok := tree.unsavedFastNodeAdditions[unsafeToStr(key)]; ok { - return fastNode.value, nil + if fastNode, ok := tree.unsavedFastNodeAdditions.Load(unsafeToStr(key)); ok { + return fastNode.(*FastNode).value, nil } // check if node was deleted - if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok { + if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok { return nil, nil } } @@ -811,8 +811,8 @@ func (tree *MutableTree) Rollback() { } tree.orphans = map[string]int64{} if !tree.skipFastStorageUpgrade { - tree.unsavedFastNodeAdditions = map[string]*FastNode{} - tree.unsavedFastNodeRemovals = map[string]interface{}{} + tree.unsavedFastNodeAdditions = &sync.Map{} + tree.unsavedFastNodeRemovals = &sync.Map{} } } @@ -931,8 +931,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { tree.lastSaved = tree.ImmutableTree.clone() tree.orphans = map[string]int64{} if !tree.skipFastStorageUpgrade { - tree.unsavedFastNodeAdditions = make(map[string]*FastNode) - tree.unsavedFastNodeRemovals = make(map[string]interface{}) + tree.unsavedFastNodeAdditions = &sync.Map{} + tree.unsavedFastNodeRemovals = &sync.Map{} } hash, err := tree.Hash() @@ -955,47 +955,62 @@ func (tree *MutableTree) saveFastNodeVersion() error { // nolint: unused func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*FastNode { - return tree.unsavedFastNodeAdditions + additions := make(map[string]*FastNode) + tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool { + additions[key.(string)] = value.(*FastNode) + return true + }) + return additions } // getUnsavedFastNodeRemovals returns unsaved FastNodes to remove // nolint: unused func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} { - return tree.unsavedFastNodeRemovals + removals := make(map[string]interface{}) + tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool { + removals[key.(string)] = value + return true + }) + return removals } +// addUnsavedAddition stores an addition into the unsaved additions map func (tree *MutableTree) addUnsavedAddition(key []byte, node *FastNode) { skey := unsafeToStr(key) - delete(tree.unsavedFastNodeRemovals, skey) - tree.unsavedFastNodeAdditions[skey] = node + tree.unsavedFastNodeRemovals.Delete(skey) + tree.unsavedFastNodeAdditions.Store(skey, node) } func (tree *MutableTree) saveFastNodeAdditions() error { - keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions)) - for key := range tree.unsavedFastNodeAdditions { - keysToSort = append(keysToSort, key) - } + keysToSort := make([]string, 0) + tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool { + keysToSort = append(keysToSort, k.(string)) + return true + }) sort.Strings(keysToSort) for _, key := range keysToSort { - if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil { + val, _ := tree.unsavedFastNodeAdditions.Load(key) + if err := tree.ndb.SaveFastNode(val.(*FastNode)); err != nil { return err } } return nil } +// addUnsavedRemoval adds a removal to the unsaved removals map func (tree *MutableTree) addUnsavedRemoval(key []byte) { skey := unsafeToStr(key) - delete(tree.unsavedFastNodeAdditions, skey) - tree.unsavedFastNodeRemovals[skey] = true + tree.unsavedFastNodeAdditions.Delete(skey) + tree.unsavedFastNodeRemovals.Store(skey, true) } func (tree *MutableTree) saveFastNodeRemovals() error { - keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals)) - for key := range tree.unsavedFastNodeRemovals { - keysToSort = append(keysToSort, key) - } + keysToSort := make([]string, 0) + tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool { + keysToSort = append(keysToSort, k.(string)) + return true + }) sort.Strings(keysToSort) for _, key := range keysToSort { diff --git a/unsaved_fast_iterator.go b/unsaved_fast_iterator.go index cbbff85fe..a5fe41ea9 100644 --- a/unsaved_fast_iterator.go +++ b/unsaved_fast_iterator.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "sort" + "sync" dbm "github.com/tendermint/tm-db" ) @@ -28,14 +29,14 @@ type UnsavedFastIterator struct { fastIterator dbm.Iterator nextUnsavedNodeIdx int - unsavedFastNodeAdditions map[string]*FastNode - unsavedFastNodeRemovals map[string]interface{} + unsavedFastNodeAdditions *sync.Map // map[string]*FastNode + unsavedFastNodeRemovals *sync.Map // map[string]interface{} unsavedFastNodesToSort []string } var _ dbm.Iterator = (*UnsavedFastIterator)(nil) -func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*FastNode, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator { +func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator { iter := &UnsavedFastIterator{ start: start, end: end, @@ -49,28 +50,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa fastIterator: NewFastIterator(start, end, ascending, ndb), } - // We need to ensure that we iterate over saved and unsaved state in order. - // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. - // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. - for _, fastNode := range unsavedFastNodeAdditions { - if start != nil && bytes.Compare(fastNode.key, start) < 0 { - continue - } - - if end != nil && bytes.Compare(fastNode.key, end) >= 0 { - continue - } - - iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, unsafeToStr(fastNode.key)) - } - - sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { - if ascending { - return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] - } - return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] - }) - if iter.ndb == nil { iter.err = errFastIteratorNilNdbGiven iter.valid = false @@ -89,7 +68,33 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa return iter } - // Move to the first elemenet + // We need to ensure that we iterate over saved and unsaved state in order. + // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. + // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. + unsavedFastNodeAdditions.Range(func(k, v interface{}) bool { + fastNode := v.(*FastNode) + + if start != nil && bytes.Compare(fastNode.key, start) < 0 { + return true + } + + if end != nil && bytes.Compare(fastNode.key, end) >= 0 { + return true + } + + iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string)) + + return true + }) + + sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { + if ascending { + return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] + } + return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] + }) + + // Move to the first element iter.Next() return iter @@ -134,8 +139,8 @@ func (iter *UnsavedFastIterator) Next() { diskKeyStr := unsafeToStr(iter.fastIterator.Key()) if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { - - if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr) + if ok && value != nil { // If next fast node from disk is to be removed, skip it. iter.fastIterator.Next() iter.Next() @@ -143,7 +148,8 @@ func (iter *UnsavedFastIterator) Next() { } nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] - nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey) + nextUnsavedNode := nextUnsavedNodeVal.(*FastNode) var isUnsavedNext bool if iter.ascending { @@ -154,7 +160,6 @@ func (iter *UnsavedFastIterator) Next() { if isUnsavedNext { // Unsaved node is next - if diskKeyStr == nextUnsavedKey { // Unsaved update prevails over saved copy so we skip the copy from disk iter.fastIterator.Next() @@ -176,7 +181,8 @@ func (iter *UnsavedFastIterator) Next() { // if only nodes on disk are left, we return them if iter.fastIterator.Valid() { - if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr) + if ok && value != nil { // If next fast node from disk is to be removed, skip it. iter.fastIterator.Next() iter.Next() @@ -193,7 +199,8 @@ func (iter *UnsavedFastIterator) Next() { // if only unsaved nodes are left, we can just iterate if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] - nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey) + nextUnsavedNode := nextUnsavedNodeVal.(*FastNode) iter.nextKey = nextUnsavedNode.key iter.nextVal = nextUnsavedNode.value