diff --git a/common/db/merge_iter.go b/common/db/merge_iter.go index fd8bf43e00..784dae8a00 100644 --- a/common/db/merge_iter.go +++ b/common/db/merge_iter.go @@ -76,7 +76,7 @@ func (i *mergedIterator) Rewind() bool { } } i.dir = dirSOI - return i.next(false) + return i.selectKey() } func (i *mergedIterator) Seek(key []byte) bool { @@ -97,7 +97,7 @@ func (i *mergedIterator) Seek(key []byte) bool { } } i.dir = dirSOI - if i.next(!i.reverse) { + if i.selectKey() { i.dir = dirSeek return true } @@ -105,27 +105,25 @@ func (i *mergedIterator) Seek(key []byte) bool { return false } -func (i *mergedIterator) compare(tkey []byte, key []byte, ignoreReverse bool) int { - if ignoreReverse { - return i.cmp.Compare(tkey, key) - } - if tkey == nil && key != nil { +func (i *mergedIterator) compare(key1 []byte, key2 []byte) int { + + if key1 == nil && key2 != nil { return 1 } - if tkey != nil && key == nil { + if key1 != nil && key2 == nil { return -1 } - result := i.cmp.Compare(tkey, key) + result := i.cmp.Compare(key1, key2) if i.reverse { return -result } return result } -func (i *mergedIterator) next(ignoreReverse bool) bool { +func (i *mergedIterator) selectKey() bool { var key []byte for x, tkey := range i.keys { - if tkey != nil && (key == nil || i.compare(tkey, key, ignoreReverse) < 0) { + if tkey != nil && (key == nil || i.compare(tkey, key) < 0) { key = tkey i.index = x } @@ -141,64 +139,40 @@ func (i *mergedIterator) next(ignoreReverse bool) bool { return true } +// Next next key func (i *mergedIterator) Next() bool { for { - ok, isrewind := i.nextInternal() - if !ok { - break - } - if isrewind { - return true + + if !i.next() { + return false } - if i.compare(i.Key(), i.prevKey, true) != 0 { + + if i.compare(i.Key(), i.prevKey) != 0 { i.prevKey = cloneByte(i.Key()) return true } } - return false } -func (i *mergedIterator) nextInternal() (bool, bool) { +func (i *mergedIterator) next() bool { if i.dir == dirEOI || i.err != nil { - return false, false + return false } else if i.dir == dirReleased { i.err = ErrIterReleased - return false, false - } - switch i.dir { - case dirSOI: - return i.Rewind(), true - case dirSeek: - if !i.reverse { - break - } - key := append([]byte{}, i.keys[i.index]...) - for x, iter := range i.iters { - if x == i.index { - continue - } - seek := iter.Seek(key) - switch { - case seek && iter.Next(), !seek && iter.Rewind(): - i.keys[x] = assertKey(iter.Key()) - case i.iterErr(iter): - return false, false - default: - i.keys[x] = nil - } - } + return false } + x := i.index iter := i.iters[x] switch { case iter.Next(): i.keys[x] = assertKey(iter.Key()) case i.iterErr(iter): - return false, false + return false default: i.keys[x] = nil } - return i.next(false), false + return i.selectKey() } func (i *mergedIterator) Key() []byte { @@ -247,7 +221,7 @@ func (i *mergedIterator) Error() error { // // If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true) // won't be ignored and will halt 'merged iterator', otherwise the iterator will -// continue to the next 'input iterator'. +// continue to the selectKey 'input iterator'. func NewMergedIterator(iters []Iterator) Iterator { reverse := true if len(iters) >= 2 { diff --git a/common/db/merge_iter_test.go b/common/db/merge_iter_test.go index 39ffdf0f84..cf556a6e60 100644 --- a/common/db/merge_iter_test.go +++ b/common/db/merge_iter_test.go @@ -1,6 +1,7 @@ package db import ( + "fmt" "io/ioutil" "os" "testing" @@ -46,6 +47,14 @@ func newGoLevelDB(t *testing.T) (DB, string) { return db, dir } +func newGoBadgerDB(t *testing.T) (DB, string) { + dir, err := ioutil.TempDir("", "badgerdb") + assert.Nil(t, err) + db, err := NewGoBadgerDB("test", dir, 16) + assert.Nil(t, err) + return db, dir +} + func TestMergeIterSeek1(t *testing.T) { db1 := newGoMemDB(t) db1.Set([]byte("1"), []byte("1")) @@ -279,3 +288,120 @@ func TestIterSearch(t *testing.T) { assert.Equal(t, "db2-key-3", string(list0[0])) assert.Equal(t, "db2-key-4", string(list0[1])) } + +func TestMergeIterList(t *testing.T) { + levelDB, dir := newGoLevelDB(t) + testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), levelDB) + _ = os.RemoveAll(dir) + badgerDB, dir := newGoBadgerDB(t) + testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), badgerDB) + _ = os.RemoveAll(dir) + levelDB, dir1 := newGoLevelDB(t) + badgerDB, dir2 := newGoBadgerDB(t) + testMergeIterList(t, badgerDB, levelDB, newGoMemDB(t)) + _ = os.RemoveAll(dir1) + _ = os.RemoveAll(dir2) +} + +func testMergeIterList(t *testing.T, db1, db2, db3 DB) { + + for i := 0; i < 10; i++ { + db3.Set([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("%d", i))) + } + //合并以后: + db := NewMergedIteratorDB([]IteratorDB{db1, db2, db3}) + it := NewListHelper(db) + + //key9 ~ key1 + listAll := func(totalCount int, direction int32) [][]byte { + var values [][]byte + var primary []byte + for i := 0; i < 3; i++ { + data := it.List([]byte("key"), primary, 4, direction) + values = append(values, data...) + primary = []byte(fmt.Sprintf("key%s", data[len(data)-1])) + } + assert.Equal(t, totalCount, len(values)) + return values + } + + values := listAll(10, ListDESC) + for i, val := range values { + assert.Equal(t, []byte(fmt.Sprintf("%d", 9-i)), val) + } + values = listAll(10, ListASC) + for i, val := range values { + assert.Equal(t, []byte(fmt.Sprintf("%d", i)), val) + } + + // db2数据覆盖 + db2.Set([]byte("key3"), []byte("33")) + values = listAll(10, ListDESC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", 9-i)) + if i == 6 { + value = []byte("33") + } + assert.Equal(t, value, val) + } + values = listAll(10, ListASC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", i)) + if i == 3 { + value = []byte("33") + } + assert.Equal(t, value, val) + } + + // db1数据覆盖 + db1.Set([]byte("key3"), []byte("333")) + db1.Set([]byte("key5"), []byte("555")) + values = listAll(10, ListDESC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", 9-i)) + if i == 4 { + value = []byte("555") + } + if i == 6 { + value = []byte("333") + } + assert.Equal(t, value, val) + } + values = listAll(10, ListASC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", i)) + if i == 5 { + value = []byte("555") + } + if i == 3 { + value = []byte("333") + } + assert.Equal(t, value, val) + } + + // 新增key + db1.Set([]byte("key91"), []byte("10")) + db2.Set([]byte("key92"), []byte("11")) + values = listAll(12, ListDESC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", 11-i)) + if i == 6 { + value = []byte("555") + } + if i == 8 { + value = []byte("333") + } + assert.Equal(t, value, val) + } + values = listAll(12, ListASC) + for i, val := range values { + value := []byte(fmt.Sprintf("%d", i)) + if i == 5 { + value = []byte("555") + } + if i == 3 { + value = []byte("333") + } + assert.Equal(t, value, val) + } +}