Skip to content

Commit

Permalink
[[FIX]] fix merge iterator reverse list(#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
bysomeone authored and 33cn committed Jan 19, 2022
1 parent 89ce5be commit fd5b2dd
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 48 deletions.
70 changes: 22 additions & 48 deletions common/db/merge_iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (i *mergedIterator) Rewind() bool {
}
}
i.dir = dirSOI
return i.next(false)
return i.selectKey()
}

func (i *mergedIterator) Seek(key []byte) bool {
Expand All @@ -97,35 +97,33 @@ func (i *mergedIterator) Seek(key []byte) bool {
}
}
i.dir = dirSOI
if i.next(!i.reverse) {
if i.selectKey() {
i.dir = dirSeek
return true
}
i.dir = dirSOI
return false
}

func (i *mergedIterator) compare(tkey []byte, key []byte, ignoreReverse bool) int {
if ignoreReverse {
return i.cmp.Compare(tkey, key)
}
if tkey == nil && key != nil {
func (i *mergedIterator) compare(key1 []byte, key2 []byte) int {

if key1 == nil && key2 != nil {
return 1
}
if tkey != nil && key == nil {
if key1 != nil && key2 == nil {
return -1
}
result := i.cmp.Compare(tkey, key)
result := i.cmp.Compare(key1, key2)
if i.reverse {
return -result
}
return result
}

func (i *mergedIterator) next(ignoreReverse bool) bool {
func (i *mergedIterator) selectKey() bool {
var key []byte
for x, tkey := range i.keys {
if tkey != nil && (key == nil || i.compare(tkey, key, ignoreReverse) < 0) {
if tkey != nil && (key == nil || i.compare(tkey, key) < 0) {
key = tkey
i.index = x
}
Expand All @@ -141,64 +139,40 @@ func (i *mergedIterator) next(ignoreReverse bool) bool {
return true
}

// Next next key
func (i *mergedIterator) Next() bool {
for {
ok, isrewind := i.nextInternal()
if !ok {
break
}
if isrewind {
return true

if !i.next() {
return false
}
if i.compare(i.Key(), i.prevKey, true) != 0 {

if i.compare(i.Key(), i.prevKey) != 0 {
i.prevKey = cloneByte(i.Key())
return true
}
}
return false
}

func (i *mergedIterator) nextInternal() (bool, bool) {
func (i *mergedIterator) next() bool {
if i.dir == dirEOI || i.err != nil {
return false, false
return false
} else if i.dir == dirReleased {
i.err = ErrIterReleased
return false, false
}
switch i.dir {
case dirSOI:
return i.Rewind(), true
case dirSeek:
if !i.reverse {
break
}
key := append([]byte{}, i.keys[i.index]...)
for x, iter := range i.iters {
if x == i.index {
continue
}
seek := iter.Seek(key)
switch {
case seek && iter.Next(), !seek && iter.Rewind():
i.keys[x] = assertKey(iter.Key())
case i.iterErr(iter):
return false, false
default:
i.keys[x] = nil
}
}
return false
}

x := i.index
iter := i.iters[x]
switch {
case iter.Next():
i.keys[x] = assertKey(iter.Key())
case i.iterErr(iter):
return false, false
return false
default:
i.keys[x] = nil
}
return i.next(false), false
return i.selectKey()
}

func (i *mergedIterator) Key() []byte {
Expand Down Expand Up @@ -247,7 +221,7 @@ func (i *mergedIterator) Error() error {
//
// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true)
// won't be ignored and will halt 'merged iterator', otherwise the iterator will
// continue to the next 'input iterator'.
// continue to the selectKey 'input iterator'.
func NewMergedIterator(iters []Iterator) Iterator {
reverse := true
if len(iters) >= 2 {
Expand Down
126 changes: 126 additions & 0 deletions common/db/merge_iter_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"fmt"
"io/ioutil"
"os"
"testing"
Expand Down Expand Up @@ -46,6 +47,14 @@ func newGoLevelDB(t *testing.T) (DB, string) {
return db, dir
}

func newGoBadgerDB(t *testing.T) (DB, string) {
dir, err := ioutil.TempDir("", "badgerdb")
assert.Nil(t, err)
db, err := NewGoBadgerDB("test", dir, 16)
assert.Nil(t, err)
return db, dir
}

func TestMergeIterSeek1(t *testing.T) {
db1 := newGoMemDB(t)
db1.Set([]byte("1"), []byte("1"))
Expand Down Expand Up @@ -279,3 +288,120 @@ func TestIterSearch(t *testing.T) {
assert.Equal(t, "db2-key-3", string(list0[0]))
assert.Equal(t, "db2-key-4", string(list0[1]))
}

func TestMergeIterList(t *testing.T) {
levelDB, dir := newGoLevelDB(t)
testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), levelDB)
_ = os.RemoveAll(dir)
badgerDB, dir := newGoBadgerDB(t)
testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), badgerDB)
_ = os.RemoveAll(dir)
levelDB, dir1 := newGoLevelDB(t)
badgerDB, dir2 := newGoBadgerDB(t)
testMergeIterList(t, badgerDB, levelDB, newGoMemDB(t))
_ = os.RemoveAll(dir1)
_ = os.RemoveAll(dir2)
}

func testMergeIterList(t *testing.T, db1, db2, db3 DB) {

for i := 0; i < 10; i++ {
db3.Set([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("%d", i)))
}
//合并以后:
db := NewMergedIteratorDB([]IteratorDB{db1, db2, db3})
it := NewListHelper(db)

//key9 ~ key1
listAll := func(totalCount int, direction int32) [][]byte {
var values [][]byte
var primary []byte
for i := 0; i < 3; i++ {
data := it.List([]byte("key"), primary, 4, direction)
values = append(values, data...)
primary = []byte(fmt.Sprintf("key%s", data[len(data)-1]))
}
assert.Equal(t, totalCount, len(values))
return values
}

values := listAll(10, ListDESC)
for i, val := range values {
assert.Equal(t, []byte(fmt.Sprintf("%d", 9-i)), val)
}
values = listAll(10, ListASC)
for i, val := range values {
assert.Equal(t, []byte(fmt.Sprintf("%d", i)), val)
}

// db2数据覆盖
db2.Set([]byte("key3"), []byte("33"))
values = listAll(10, ListDESC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", 9-i))
if i == 6 {
value = []byte("33")
}
assert.Equal(t, value, val)
}
values = listAll(10, ListASC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", i))
if i == 3 {
value = []byte("33")
}
assert.Equal(t, value, val)
}

// db1数据覆盖
db1.Set([]byte("key3"), []byte("333"))
db1.Set([]byte("key5"), []byte("555"))
values = listAll(10, ListDESC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", 9-i))
if i == 4 {
value = []byte("555")
}
if i == 6 {
value = []byte("333")
}
assert.Equal(t, value, val)
}
values = listAll(10, ListASC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", i))
if i == 5 {
value = []byte("555")
}
if i == 3 {
value = []byte("333")
}
assert.Equal(t, value, val)
}

// 新增key
db1.Set([]byte("key91"), []byte("10"))
db2.Set([]byte("key92"), []byte("11"))
values = listAll(12, ListDESC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", 11-i))
if i == 6 {
value = []byte("555")
}
if i == 8 {
value = []byte("333")
}
assert.Equal(t, value, val)
}
values = listAll(12, ListASC)
for i, val := range values {
value := []byte(fmt.Sprintf("%d", i))
if i == 5 {
value = []byte("555")
}
if i == 3 {
value = []byte("333")
}
assert.Equal(t, value, val)
}
}

0 comments on commit fd5b2dd

Please sign in to comment.