diff --git a/bitset.go b/bitset.go index 8d42535..bbb75ff 100644 --- a/bitset.go +++ b/bitset.go @@ -1185,3 +1185,121 @@ func (b *BitSet) Select(index uint) uint { } return b.length } + +// top detects the top bit set +func (b *BitSet) top() (uint, bool) { + panicIfNull(b) + + idx := len(b.set) - 1 + for ; idx >= 0 && b.set[idx] == 0; idx-- { + } + + // no set bits + if idx < 0 { + return 0, false + } + + return uint(idx)*wordSize + len64(b.set[idx]) - 1, true +} + +// ShiftLeft shifts the bitset like << operation would do. +// +// Left shift may require bitset size extension. We try to avoid the +// unnecessary memory operations by detecting the leftmost set bit. +// The function will panic if shift causes excess of capacity. +func (b *BitSet) ShiftLeft(bits uint) { + panicIfNull(b) + + if bits == 0 { + return + } + + top, ok := b.top() + if !ok { + return + } + + // capacity check + if top+bits >= Cap() { + panic("You are exceeding the capacity") + } + + // destination set + dst := b.set + + // not using extendSet() to avoid unneeded data copying + nsize := wordsNeeded(top + bits) + if len(b.set) < nsize { + dst = make([]uint64, nsize, 2*nsize) + } + if top+bits >= b.length { + b.length = top + bits + 1 + } + + pad, idx := top%wordSize, top>>log2WordSize + shift, pages := bits%wordSize, bits>>log2WordSize + if bits%wordSize == 0 { // happy case: just add pages + copy(dst[pages:nsize], b.set) + } else { + if pad+shift >= wordSize { + dst[idx+pages+1] = b.set[idx] >> (wordSize - shift) + } + + for i := int(idx); i >= 0; i-- { + if i > 0 { + dst[i+int(pages)] = (b.set[i] << shift) | (b.set[i-1] >> (wordSize - shift)) + } else { + dst[i+int(pages)] = b.set[i] << shift + } + } + } + + // zeroing extra pages + for i := 0; i < int(pages); i++ { + dst[i] = 0 + } + + b.set = dst +} + +// ShiftRight shifts the bitset like >> operation would do. +func (b *BitSet) ShiftRight(bits uint) { + panicIfNull(b) + + if bits == 0 { + return + } + + top, ok := b.top() + if !ok { + return + } + + if bits >= top { + b.set = make([]uint64, wordsNeeded(b.length)) + return + } + + pad, idx := top%wordSize, top>>log2WordSize + shift, pages := bits%wordSize, bits>>log2WordSize + if bits%wordSize == 0 { // happy case: just clear pages + b.set = b.set[pages:] + b.length -= pages * wordSize + } else { + for i := 0; i <= int(idx-pages); i++ { + if i < int(idx-pages) { + b.set[i] = (b.set[i+int(pages)] >> shift) | (b.set[i+int(pages)+1] << (wordSize - shift)) + } else { + b.set[i] = b.set[i+int(pages)] >> shift + } + } + + if pad < shift { + b.set[int(idx-pages)] = 0 + } + } + + for i := int(idx-pages) + 1; i <= int(idx); i++ { + b.set[i] = 0 + } +} diff --git a/bitset_test.go b/bitset_test.go index 34d628a..4114bec 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -1963,3 +1963,75 @@ func TestSetAll(t *testing.T) { test(fmt.Sprintf("length %d", length), New(length), length) } } + +func TestShiftLeft(t *testing.T) { + data := []uint{5, 28, 45, 72, 89} + + test := func(name string, bits uint) { + t.Run(name, func(t *testing.T) { + b := New(200) + for _, i := range data { + b.Set(i) + } + + b.ShiftLeft(bits) + + if int(b.Count()) != len(data) { + t.Error("bad bits count") + } + + for _, i := range data { + if !b.Test(i + bits) { + t.Errorf("bit %v is not set", i+bits) + } + } + }) + } + + test("zero", 0) + test("no page change", 19) + test("shift to full page", 38) + test("full page shift", 64) + test("no page split", 80) + test("with page split", 114) + test("with extension", 242) +} + +func TestShiftRight(t *testing.T) { + data := []uint{5, 28, 45, 72, 89} + + test := func(name string, bits uint) { + t.Run(name, func(t *testing.T) { + b := New(200) + for _, i := range data { + b.Set(i) + } + + b.ShiftRight(bits) + + count := 0 + for _, i := range data { + if i > bits { + count++ + + if !b.Test(i - bits) { + t.Errorf("bit %v is not set", i-bits) + } + } + } + + if int(b.Count()) != count { + t.Error("bad bits count") + } + }) + } + + test("zero", 0) + test("no page change", 3) + test("no page split", 20) + test("with page split", 40) + test("full page shift", 64) + test("with extension", 70) + test("full shift", 89) + test("remove all", 242) +} diff --git a/leading_zeros_18.go b/leading_zeros_18.go new file mode 100644 index 0000000..cd10a88 --- /dev/null +++ b/leading_zeros_18.go @@ -0,0 +1,43 @@ +//go:build !go1.9 +// +build !go1.9 + +package bitset + +var len8tab = "" + + "\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" + + "\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" + + "\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" + + "\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + +// Len64 returns the minimum number of bits required to represent x; the result is 0 for x == 0. +func len64(x uint64) (n uint) { + if x >= 1<<32 { + x >>= 32 + n = 32 + } + if x >= 1<<16 { + x >>= 16 + n += 16 + } + if x >= 1<<8 { + x >>= 8 + n += 8 + } + return n + uint(len8tab[x]) +} + +func leadingZeroes64(v uint64) uint { + return 64 - len64(x) +} diff --git a/leading_zeros_19.go b/leading_zeros_19.go new file mode 100644 index 0000000..74a7942 --- /dev/null +++ b/leading_zeros_19.go @@ -0,0 +1,14 @@ +//go:build go1.9 +// +build go1.9 + +package bitset + +import "math/bits" + +func len64(v uint64) uint { + return uint(bits.Len64(v)) +} + +func leadingZeroes64(v uint64) uint { + return uint(bits.LeadingZeros64(v)) +}