From b0c6e5233cda96b5f9b8f86a46bbbd54be545795 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 29 Oct 2020 14:38:58 -0700 Subject: [PATCH] Refactor for better caching (#30) This refactor: 1. Ensures we never write anything till we flush. This: 1. Saves us quite a bit of time/gas when making many state modifications. 2. Gives us some room to further optimize batch operations without requiring a network upgrade. 3. Is much easier to reliably re-implement (e.g., in other languages). 2. Completely decodes nodes on load, and re-encodes them on save. This means all bitfield operations are isolated to two functions and bitfields do not need to be maintained in state (they're generated on the fly on flush). 3. Checks a bunch of invariants. We can't check everything, but we can at least avoid doing anything terribly incorrect. --- amt.go | 502 +++++----------------------- amt_test.go | 197 +++++++++-- gen/gen.go | 4 +- go.mod | 2 +- cbor_gen.go => internal/cbor_gen.go | 6 +- internal/internal.go | 33 ++ invalid_test.go | 52 +++ link.go | 32 ++ node.go | 328 ++++++++++++++++++ util.go | 12 + 10 files changed, 715 insertions(+), 453 deletions(-) rename cbor_gen.go => internal/cbor_gen.go (98%) create mode 100644 internal/internal.go create mode 100644 invalid_test.go create mode 100644 link.go create mode 100644 node.go create mode 100644 util.go diff --git a/amt.go b/amt.go index baf3e72..b8bb9b1 100644 --- a/amt.go +++ b/amt.go @@ -4,65 +4,73 @@ import ( "bytes" "context" "fmt" + "sort" cid "github.com/ipfs/go-cid" cbor "github.com/ipfs/go-ipld-cbor" logging "github.com/ipfs/go-log" cbg "github.com/whyrusleeping/cbor-gen" + + "github.com/filecoin-project/go-amt-ipld/v3/internal" ) var log = logging.Logger("amt") -const ( - // Width must be a power of 2. We set this to 8. - maxIndexBits = 63 - widthBits = 3 - width = 1 << widthBits // 8 - bitfieldSize = 1 // ((width - 1) >> 3) + 1 - maxHeight = maxIndexBits/widthBits - 1 // 20 (because the root is at height 0). -) - // MaxIndex is the maximum index for elements in the AMT. This is currently 1^63 // (max int) because the width is 8. That means every "level" consumes 3 bits // from the index, and 63/3 is a nice even 21 -const MaxIndex = uint64(1< maxHeight { - return nil, fmt.Errorf("failed to load AMT: height out of bounds: %d > %d", r.Height, maxHeight) + if r.Height > internal.MaxHeight { + return nil, fmt.Errorf("failed to load AMT: height out of bounds: %d > %d", r.Height, internal.MaxHeight) + } + if nodesForHeight(int(r.Height+1)) < r.Count { + return nil, fmt.Errorf( + "failed to load AMT: not tall enough (%d) for count (%d)", r.Height, r.Count, + ) } - r.store = bs + nd, err := newNode(r.Node, r.Height == 0, r.Height == 0) + if err != nil { + return nil, err + } - return &r, nil + return &Root{ + height: int(r.Height), + count: r.Count, + node: nd, + store: bs, + }, nil +} + +func FromArray(ctx context.Context, bs cbor.IpldStore, vals []cbg.CBORMarshaler) (cid.Cid, error) { + r := NewAMT(bs) + if err := r.BatchSet(ctx, vals); err != nil { + return cid.Undef, err + } + + return r.Flush(ctx) } func (r *Root) Set(ctx context.Context, i uint64, val interface{}) error { @@ -85,46 +93,37 @@ func (r *Root) Set(ctx context.Context, i uint64, val interface{}) error { } } - for i >= nodesForHeight(int(r.Height)+1) { - if !r.Node.empty() { - if err := r.Node.Flush(ctx, r.store, int(r.Height)); err != nil { - return err - } - - c, err := r.store.Put(ctx, &r.Node) - if err != nil { - return err - } - - r.Node = Node{ - Bmap: [...]byte{0x01}, - Links: []cid.Cid{c}, + for i >= nodesForHeight(r.height+1) { + if !r.node.empty() { + nd := r.node + r.node = &node{ + links: [internal.Width]*link{ + 0: { + dirty: true, + cached: nd, + }, + }, } } - r.Height++ + r.height++ } - addVal, err := r.Node.set(ctx, r.store, int(r.Height), i, &cbg.Deferred{Raw: b}) + addVal, err := r.node.set(ctx, r.store, int(r.height), i, &cbg.Deferred{Raw: b}) if err != nil { return err } if addVal { - r.Count++ + // Something is wrong, so we'll just do our best to not overflow. + if r.count >= (MaxIndex - 1) { + return errInvalidCount + } + r.count++ } return nil } -func FromArray(ctx context.Context, bs cbor.IpldStore, vals []cbg.CBORMarshaler) (cid.Cid, error) { - r := NewAMT(bs) - if err := r.BatchSet(ctx, vals); err != nil { - return cid.Undef, err - } - - return r.Flush(ctx) -} - func (r *Root) BatchSet(ctx context.Context, vals []cbg.CBORMarshaler) error { // TODO: there are more optimized ways of doing this method for i, v := range vals { @@ -140,41 +139,28 @@ func (r *Root) Get(ctx context.Context, i uint64, out interface{}) error { return fmt.Errorf("index %d is out of range for the amt", i) } - if i >= nodesForHeight(int(r.Height+1)) { + if i >= nodesForHeight(int(r.height+1)) { return &ErrNotFound{Index: i} } - return r.Node.get(ctx, r.store, int(r.Height), i, out) -} - -func (n *Node) get(ctx context.Context, bs cbor.IpldStore, height int, i uint64, out interface{}) error { - subi := i / nodesForHeight(height) - if !n.isSet(subi) { - return &ErrNotFound{i} - } - if height == 0 { - if err := n.expandValues(); err != nil { - return err - } - - d := n.expVals[i] - - if um, ok := out.(cbg.CBORUnmarshaler); ok { - return um.UnmarshalCBOR(bytes.NewReader(d.Raw)) - } else { - return cbor.DecodeInto(d.Raw, out) - } - } - - subn, err := n.loadNode(ctx, bs, subi, false) - if err != nil { + if found, err := r.node.get(ctx, r.store, int(r.height), i, out); err != nil { return err + } else if !found { + return &ErrNotFound{Index: i} } - - return subn.get(ctx, bs, height-1, i%nodesForHeight(height), out) + return nil } func (r *Root) BatchDelete(ctx context.Context, indices []uint64) error { // TODO: theres a faster way of doing this, but this works for now + + // Sort by index so we can safely implement these optimizations in the future. + less := func(i, j int) bool { return indices[i] < indices[j] } + if !sort.SliceIsSorted(indices, less) { + // Copy first so we don't modify our inputs. + indices = append(indices[0:0:0], indices...) + sort.Slice(indices, less) + } + for _, i := range indices { if err := r.Delete(ctx, i); err != nil { return err @@ -188,60 +174,30 @@ func (r *Root) Delete(ctx context.Context, i uint64) error { if i > MaxIndex { return fmt.Errorf("index %d is out of range for the amt", i) } - //fmt.Printf("i: %d, h: %d, nfh: %d\n", i, r.Height, nodesForHeight(int(r.Height))) - if i >= nodesForHeight(int(r.Height+1)) { + if i >= nodesForHeight(int(r.height+1)) { return &ErrNotFound{i} } - if err := r.Node.delete(ctx, r.store, int(r.Height), i); err != nil { + found, err := r.node.delete(ctx, r.store, int(r.height), i) + if err != nil { return err - } - r.Count-- - - for r.Node.Bmap[0] == 1 && r.Height > 0 { - sub, err := r.Node.loadNode(ctx, r.store, 0, false) - if err != nil { - return err - } - - r.Node = *sub - r.Height-- - } - - return nil -} - -func (n *Node) delete(ctx context.Context, bs cbor.IpldStore, height int, i uint64) error { - subi := i / nodesForHeight(height) - if !n.isSet(subi) { + } else if !found { return &ErrNotFound{i} } - if height == 0 { - if err := n.expandValues(); err != nil { - return err - } - - n.expVals[i] = nil - n.clearBit(i) - - return nil - } - subn, err := n.loadNode(ctx, bs, subi, false) + newHeight, err := r.node.collapse(ctx, r.store, r.height) if err != nil { return err } + r.height = newHeight - if err := subn.delete(ctx, bs, height-1, i%nodesForHeight(height)); err != nil { - return err - } - - if subn.empty() { - n.clearBit(subi) - n.cache[subi] = nil - n.expLinks[subi] = cid.Undef + // Something is very wrong but there's not much we can do. So we perform + // the operation and then tell the user that something is wrong. + if r.count == 0 { + return errInvalidCount } + r.count-- return nil } @@ -254,308 +210,32 @@ func (r *Root) Subtract(ctx context.Context, or *Root) error { } func (r *Root) ForEach(ctx context.Context, cb func(uint64, *cbg.Deferred) error) error { - return r.Node.forEachAt(ctx, r.store, int(r.Height), 0, 0, cb) + return r.node.forEachAt(ctx, r.store, r.height, 0, 0, cb) } func (r *Root) ForEachAt(ctx context.Context, start uint64, cb func(uint64, *cbg.Deferred) error) error { - return r.Node.forEachAt(ctx, r.store, int(r.Height), start, 0, cb) -} - -func (n *Node) forEachAt(ctx context.Context, bs cbor.IpldStore, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error) error { - if height == 0 { - if err := n.expandValues(); err != nil { - return err - } - - for i, v := range n.expVals { - if v != nil { - ix := offset + uint64(i) - if ix < start { - continue - } - - if err := cb(offset+uint64(i), v); err != nil { - return err - } - } - } - - return nil - } - - if n.cache == nil { - if err := n.expandLinks(); err != nil { - return err - } - } - - subCount := nodesForHeight(height) - for i, v := range n.expLinks { - var sub Node - if n.cache[i] != nil { - sub = *n.cache[i] - } else if v != cid.Undef { - if err := bs.Get(ctx, v, &sub); err != nil { - return err - } - } else { - continue - } - - offs := offset + (uint64(i) * subCount) - nextOffs := offs + subCount - if start >= nextOffs { - continue - } - - if err := sub.forEachAt(ctx, bs, height-1, start, offs, cb); err != nil { - return err - } - } - return nil - + return r.node.forEachAt(ctx, r.store, r.height, start, 0, cb) } func (r *Root) FirstSetIndex(ctx context.Context) (uint64, error) { - return r.Node.firstSetIndex(ctx, r.store, int(r.Height)) -} - -var errNoVals = fmt.Errorf("no values") - -func (n *Node) firstSetIndex(ctx context.Context, bs cbor.IpldStore, height int) (uint64, error) { - if height == 0 { - if err := n.expandValues(); err != nil { - return 0, err - } - for i, v := range n.expVals { - if v != nil { - return uint64(i), nil - } - } - // Would be really weird if we ever actually hit this - return 0, errNoVals - } - - if n.cache == nil { - if err := n.expandLinks(); err != nil { - return 0, err - } - } - - for i := 0; i < width; i++ { - if n.isSet(uint64(i)) { - subn, err := n.loadNode(ctx, bs, uint64(i), false) - if err != nil { - return 0, err - } - - ix, err := subn.firstSetIndex(ctx, bs, height-1) - if err != nil { - return 0, err - } - - subCount := nodesForHeight(height) - return ix + (uint64(i) * subCount), nil - } - } - - return 0, errNoVals -} - -func (n *Node) expandValues() error { - if len(n.expVals) == 0 { - n.expVals = make([]*cbg.Deferred, width) - i := 0 - for x := uint64(0); x < width; x++ { - if n.isSet(x) { - if i >= len(n.Values) { - n.expVals = nil - return fmt.Errorf("bitfield does not match values") - } - n.expVals[x] = n.Values[i] - i++ - } - } - } - return nil -} - -func (n *Node) set(ctx context.Context, bs cbor.IpldStore, height int, i uint64, val *cbg.Deferred) (bool, error) { - //nfh := nodesForHeight(height) - //fmt.Printf("[set] h: %d, i: %d, subi: %d\n", height, i, i/nfh) - if height == 0 { - if err := n.expandValues(); err != nil { - return false, err - } - alreadySet := n.isSet(i) - n.expVals[i] = val - n.setBit(i) - - return !alreadySet, nil - } - - nfh := nodesForHeight(height) - - subn, err := n.loadNode(ctx, bs, i/nfh, true) - if err != nil { - return false, err - } - - return subn.set(ctx, bs, height-1, i%nfh, val) -} - -func (n *Node) isSet(i uint64) bool { - if i > 7 { - panic("cant deal with wider arrays yet") - } - - return len(n.Bmap) != 0 && n.Bmap[0]&byte(1< 7 { - panic("cant deal with wider arrays yet") - } - - if len(n.Bmap) == 0 { - n.Bmap = [...]byte{0} - } - - n.Bmap[0] = n.Bmap[0] | byte(1< 7 { - panic("cant deal with wider arrays yet") - } - - if len(n.Bmap) == 0 { - panic("invariant violated: called clear bit on empty node") - } - - mask := byte(0xff - (1 << i)) - - n.Bmap[0] = n.Bmap[0] & mask -} - -func (n *Node) expandLinks() error { - n.cache = make([]*Node, width) - n.expLinks = make([]cid.Cid, width) - i := 0 - for x := uint64(0); x < width; x++ { - if n.isSet(x) { - if i >= len(n.Links) { - n.cache = nil - n.expLinks = nil - return fmt.Errorf("bitfield does not match links") - } - n.expLinks[x] = n.Links[i] - i++ - } - } - return nil -} - -func (n *Node) loadNode(ctx context.Context, bs cbor.IpldStore, i uint64, create bool) (*Node, error) { - if n.cache == nil { - if err := n.expandLinks(); err != nil { - return nil, err - } - } else { - if n := n.cache[i]; n != nil { - return n, nil - } - } - - var subn *Node - if n.isSet(i) { - var sn Node - if err := bs.Get(ctx, n.expLinks[i], &sn); err != nil { - return nil, err - } - - subn = &sn - } else { - if create { - subn = &Node{} - n.setBit(i) - } else { - return nil, fmt.Errorf("no node found at (sub)index %d", i) - } - } - n.cache[i] = subn - - return subn, nil -} - -func nodesForHeight(height int) uint64 { - heightLogTwo := uint64(widthBits * height) - if heightLogTwo >= 64 { - // Should never happen. Max height is checked at all entry points. - panic("height overflow") - } - return 1 << heightLogTwo + return r.node.firstSetIndex(ctx, r.store, r.height) } func (r *Root) Flush(ctx context.Context) (cid.Cid, error) { - if err := r.Node.Flush(ctx, r.store, int(r.Height)); err != nil { + nd, err := r.node.flush(ctx, r.store, r.height) + if err != nil { return cid.Undef, err } - - return r.store.Put(ctx, r) -} - -func (n *Node) empty() bool { - return len(n.Bmap) == 0 || n.Bmap[0] == 0 -} - -func (n *Node) Flush(ctx context.Context, bs cbor.IpldStore, depth int) error { - if depth == 0 { - if len(n.expVals) == 0 { - return nil - } - n.Bmap = [...]byte{0} - n.Values = nil - for i := uint64(0); i < width; i++ { - v := n.expVals[i] - if v != nil { - n.Values = append(n.Values, v) - n.setBit(i) - } - } - return nil - } - - if len(n.expLinks) == 0 { - // nothing to do! - return nil - } - - n.Bmap = [...]byte{0} - n.Links = nil - - for i := uint64(0); i < width; i++ { - subn := n.cache[i] - if subn != nil { - if err := subn.Flush(ctx, bs, depth-1); err != nil { - return err - } - - c, err := bs.Put(ctx, subn) - if err != nil { - return err - } - n.expLinks[i] = c - } - - l := n.expLinks[i] - if l != cid.Undef { - n.Links = append(n.Links, l) - n.setBit(i) - } + root := internal.Root{ + Height: uint64(r.height), + Count: r.count, + Node: *nd, } + return r.store.Put(ctx, &root) +} - return nil +func (r *Root) Len() uint64 { + return r.count } type ErrNotFound struct { diff --git a/amt_test.go b/amt_test.go index 35bc5c6..6fc8cd4 100644 --- a/amt_test.go +++ b/amt_test.go @@ -1,12 +1,14 @@ package amt import ( + "bytes" "context" "fmt" "math/rand" "testing" "time" + "github.com/filecoin-project/go-amt-ipld/v3/internal" block "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" cbor "github.com/ipfs/go-ipld-cbor" @@ -15,16 +17,28 @@ import ( cbg "github.com/whyrusleeping/cbor-gen" ) +var numbers []cbg.CBORMarshaler + +func init() { + numbers = make([]cbg.CBORMarshaler, 10) + for i := range numbers { + val := cbg.CborInt(i) + numbers[i] = &val + } +} + type mockBlocks struct { - data map[cid.Cid]block.Block + data map[cid.Cid]block.Block + getCount, putCount int } func newMockBlocks() *mockBlocks { - return &mockBlocks{make(map[cid.Cid]block.Block)} + return &mockBlocks{make(map[cid.Cid]block.Block), 0, 0} } func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) { d, ok := mb.data[c] + mb.getCount++ if ok { return d, nil } @@ -32,10 +46,16 @@ func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) { } func (mb *mockBlocks) Put(b block.Block) error { + mb.putCount++ mb.data[b.Cid()] = b return nil } +func (mb *mockBlocks) report(b *testing.B) { + b.ReportMetric(float64(mb.getCount)/float64(b.N), "gets/op") + b.ReportMetric(float64(mb.putCount)/float64(b.N), "puts/op") +} + func TestBasicSetGet(t *testing.T) { bs := cbor.NewCborStore(newMockBlocks()) ctx := context.Background() @@ -61,29 +81,61 @@ func TestBasicSetGet(t *testing.T) { } +func TestRoundTrip(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + a := NewAMT(bs) + emptyCid, err := a.Flush(ctx) + require.NoError(t, err) + + k := uint64(100000) + assertSet(t, a, k, "foo") + assertDelete(t, a, k) + + c, err := a.Flush(ctx) + require.NoError(t, err) + + require.Equal(t, emptyCid, c) +} + func TestOutOfRange(t *testing.T) { ctx := context.Background() bs := cbor.NewCborStore(newMockBlocks()) a := NewAMT(bs) - err := a.Set(ctx, 1<<63+4, "what is up") + err := a.Set(ctx, 1<<63+4, "what is up 1") if err == nil { t.Fatal("should have failed to set value out of range") } - err = a.Set(ctx, MaxIndex+1, "what is up") + err = a.Set(ctx, MaxIndex+1, "what is up 2") if err == nil { t.Fatal("should have failed to set value out of range") } - err = a.Set(ctx, MaxIndex, "what is up") + err = a.Set(ctx, MaxIndex, "what is up 3") if err != nil { t.Fatal(err) } - if a.Height != maxHeight { + if a.height != internal.MaxHeight { t.Fatal("expected to be at the maximum height") } + + var out string + require.NoError(t, a.Get(ctx, MaxIndex, &out)) + require.Equal(t, "what is up 3", out) + + err = a.Get(ctx, MaxIndex+1, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") + + err = a.Delete(ctx, MaxIndex) + require.NoError(t, err) + + err = a.Delete(ctx, MaxIndex+1) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") } func assertDelete(t *testing.T, r *Root, i uint64) { @@ -122,7 +174,7 @@ func assertSet(t *testing.T, r *Root, i uint64, val string) { func assertCount(t testing.TB, r *Root, c uint64) { t.Helper() - if r.Count != c { + if r.count != c { t.Fatal("count is wrong") } } @@ -307,8 +359,8 @@ func TestChaos(t *testing.T) { fail := false correctLen := uint64(len(testMap)) - if correctLen != a.Count { - t.Errorf("bad length before flush, correct: %d, Count: %d, i: %d", correctLen, a.Count, i) + if correctLen != a.Len() { + t.Errorf("bad length before flush, correct: %d, Count: %d, i: %d", correctLen, a.Len(), i) fail = true } @@ -317,8 +369,8 @@ func TestChaos(t *testing.T) { a, err = LoadAMT(ctx, bs, c) assert.NoError(t, err) - if correctLen != a.Count { - t.Errorf("bad length after flush, correct: %d, Count: %d, i: %d", correctLen, a.Count, i) + if correctLen != a.Len() { + t.Errorf("bad length after flush, correct: %d, Count: %d, i: %d", correctLen, a.Len(), i) fail = true } @@ -401,7 +453,7 @@ func TestInsertABunchWithDelete(t *testing.T) { } t.Logf("originSN: %d, removeSN: %d; expected: %d, actual len(n2a): %d", - len(originSet), len(removeSet), len(originSet)-len(removeSet), n2a.Count) + len(originSet), len(removeSet), len(originSet)-len(removeSet), n2a.Len()) assertCount(t, n2a, uint64(len(originSet)-len(removeSet))) for i := uint64(0); i < uint64(num); i++ { @@ -455,23 +507,16 @@ func TestDelete(t *testing.T) { assertGet(ctx, t, a, 3, "cat") assertDelete(t, a, 0) - fmt.Printf("%b\n", a.Node.Bmap[0]) assertDelete(t, a, 2) - fmt.Printf("%b\n", a.Node.Bmap[0]) assertDelete(t, a, 3) - fmt.Printf("%b\n", a.Node.Bmap[0]) assertCount(t, a, 0) fmt.Println("trying deeper operations now") assertSet(t, a, 23, "dog") - fmt.Printf("%b\n", a.Node.Bmap[0]) assertSet(t, a, 24, "dog") - fmt.Printf("%b\n", a.Node.Bmap[0]) - fmt.Println("FAILURE NEXT") assertDelete(t, a, 23) - fmt.Printf("%b\n", a.Node.Bmap[0]) assertCount(t, a, 1) @@ -541,21 +586,50 @@ func TestDeleteReduceHeight(t *testing.T) { } func BenchmarkAMTInsertBulk(b *testing.B) { - bs := cbor.NewCborStore(newMockBlocks()) + mock := newMockBlocks() + defer mock.report(b) + + bs := cbor.NewCborStore(mock) ctx := context.Background() - a := NewAMT(bs) - for i := uint64(b.N); i > 0; i-- { - if err := a.Set(ctx, i, "some value"); err != nil { + for i := 0; i < b.N; i++ { + a := NewAMT(bs) + + num := uint64(5000) + + for i := uint64(0); i < num; i++ { + if err := a.Set(ctx, i, "foo foo bar"); err != nil { + b.Fatal(err) + } + } + + for i := uint64(0); i < num; i++ { + assertGet(ctx, b, a, i, "foo foo bar") + } + + c, err := a.Flush(ctx) + if err != nil { b.Fatal(err) } - } - assertCount(b, a, uint64(b.N)) + na, err := LoadAMT(ctx, bs, c) + if err != nil { + b.Fatal(err) + } + + for i := uint64(0); i < num; i++ { + assertGet(ctx, b, na, i, "foo foo bar") + } + + assertCount(b, na, num) + } } func BenchmarkAMTLoadAndInsert(b *testing.B) { - bs := cbor.NewCborStore(newMockBlocks()) + mock := newMockBlocks() + defer mock.report(b) + + bs := cbor.NewCborStore(mock) ctx := context.Background() a := NewAMT(bs) @@ -699,8 +773,9 @@ func TestFirstSetIndex(t *testing.T) { bs := cbor.NewCborStore(newMockBlocks()) ctx := context.Background() - vals := []uint64{0, 1, 5, width, width + 1, 276, 1234, 62881923} - for _, v := range vals { + vals := []uint64{0, 1, 5, internal.Width, internal.Width + 1, 276, 1234, 62881923} + for i, v := range vals { + t.Log(i, v) a := NewAMT(bs) if err := a.Set(ctx, v, fmt.Sprint(v)); err != nil { t.Fatal(err) @@ -733,6 +808,10 @@ func TestFirstSetIndex(t *testing.T) { if fsi != v { t.Fatal("got wrong index out after serialization") } + err = after.Delete(ctx, v) + require.NoError(t, err) + fsi, err = after.FirstSetIndex(ctx) + require.Error(t, err) } } @@ -765,17 +844,63 @@ func TestEmptyCIDStability(t *testing.T) { func TestBadBitfield(t *testing.T) { bs := cbor.NewCborStore(newMockBlocks()) ctx := context.Background() - a := NewAMT(bs) - a.Node.Bmap[0] = 0xff - a.Height = 10 - a.Count = 10 - c, err := bs.Put(ctx, a) + subnode, err := bs.Put(ctx, new(internal.Node)) require.NoError(t, err) - a2, err := LoadAMT(ctx, bs, c) + var root internal.Root + root.Node.Bmap[0] = 0xff + root.Node.Links = append(root.Node.Links, subnode) + root.Height = 10 + root.Count = 10 + c, err := bs.Put(ctx, &root) require.NoError(t, err) - var out string - err = a2.Get(ctx, 100, &out) + + _, err = LoadAMT(ctx, bs, c) require.Error(t, err) } + +func TestFromArray(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + + c, err := FromArray(ctx, bs, numbers) + require.NoError(t, err) + a, err := LoadAMT(ctx, bs, c) + require.NoError(t, err) + assertEquals(ctx, t, a, numbers) + assertCount(t, a, 10) +} + +func TestBatch(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + a := NewAMT(bs) + + require.NoError(t, a.BatchSet(ctx, numbers)) + assertEquals(ctx, t, a, numbers) + + c, err := a.Flush(ctx) + if err != nil { + t.Fatal(err) + } + + clean, err := LoadAMT(ctx, bs, c) + if err != nil { + t.Fatal(err) + } + + assertEquals(ctx, t, clean, numbers) + require.NoError(t, a.BatchDelete(ctx, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})) + assertCount(t, a, 0) +} + +func assertEquals(ctx context.Context, t testing.TB, a *Root, values []cbg.CBORMarshaler) { + require.NoError(t, a.ForEach(ctx, func(i uint64, val *cbg.Deferred) error { + var buf bytes.Buffer + require.NoError(t, values[i].MarshalCBOR(&buf)) + require.Equal(t, buf.Bytes(), val.Raw) + return nil + })) + assertCount(t, a, uint64(len(values))) +} diff --git a/gen/gen.go b/gen/gen.go index 2ba7170..d374003 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -3,11 +3,11 @@ package main import ( cbg "github.com/whyrusleeping/cbor-gen" - "github.com/filecoin-project/go-amt-ipld/v2" + "github.com/filecoin-project/go-amt-ipld/v3/internal" ) func main() { - if err := cbg.WriteTupleEncodersToFile("cbor_gen.go", "amt", amt.Root{}, amt.Node{}); err != nil { + if err := cbg.WriteTupleEncodersToFile("internal/cbor_gen.go", "internal", internal.Root{}, internal.Node{}); err != nil { panic(err) } } diff --git a/go.mod b/go.mod index 76c2351..2e162ab 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/filecoin-project/go-amt-ipld/v2 +module github.com/filecoin-project/go-amt-ipld/v3 go 1.12 diff --git a/cbor_gen.go b/internal/cbor_gen.go similarity index 98% rename from cbor_gen.go rename to internal/cbor_gen.go index f316bee..1d051c3 100644 --- a/cbor_gen.go +++ b/internal/cbor_gen.go @@ -1,6 +1,6 @@ // Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. -package amt +package internal import ( "fmt" @@ -38,7 +38,7 @@ func (t *Root) MarshalCBOR(w io.Writer) error { return err } - // t.Node (amt.Node) (struct) + // t.Node (internal.Node) (struct) if err := t.Node.MarshalCBOR(w); err != nil { return err } @@ -91,7 +91,7 @@ func (t *Root) UnmarshalCBOR(r io.Reader) error { t.Count = uint64(extra) } - // t.Node (amt.Node) (struct) + // t.Node (internal.Node) (struct) { diff --git a/internal/internal.go b/internal/internal.go new file mode 100644 index 0000000..6f6236c --- /dev/null +++ b/internal/internal.go @@ -0,0 +1,33 @@ +package internal + +import ( + cid "github.com/ipfs/go-cid" + cbg "github.com/whyrusleeping/cbor-gen" +) + +const ( + // Width must be a power of 2. We set this to 8. + MaxIndexBits = 63 + WidthBits = 3 + Width = 1 << WidthBits // 8 + BitfieldSize = 1 // ((width - 1) >> 3) + 1 + MaxHeight = MaxIndexBits/WidthBits - 1 // 20 (because the root is at height 0). +) + +func init() { + if BitfieldSize != ((Width-1)>>3)+1 { + panic("bitfield size must match width") + } +} + +type Node struct { + Bmap [BitfieldSize]byte + Links []cid.Cid + Values []*cbg.Deferred +} + +type Root struct { + Height uint64 + Count uint64 + Node Node +} diff --git a/invalid_test.go b/invalid_test.go new file mode 100644 index 0000000..3f5e9aa --- /dev/null +++ b/invalid_test.go @@ -0,0 +1,52 @@ +package amt + +import ( + "context" + "testing" + + cbor "github.com/ipfs/go-ipld-cbor" + "github.com/stretchr/testify/require" +) + +func TestInvalidHeightEmpty(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + a := NewAMT(bs) + a.height = 1 + c, err := a.Flush(ctx) + require.NoError(t, err) + _, err = LoadAMT(ctx, bs, c) + require.Error(t, err) +} + +func TestInvalidHeightSingle(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + a := NewAMT(bs) + err := a.Set(ctx, 0, 0) + require.NoError(t, err) + + a.height = 1 + c, err := a.Flush(ctx) + require.NoError(t, err) + _, err = LoadAMT(ctx, bs, c) + require.Error(t, err) +} + +func TestInvalidHeightTall(t *testing.T) { + bs := cbor.NewCborStore(newMockBlocks()) + ctx := context.Background() + a := NewAMT(bs) + err := a.Set(ctx, 15, 0) + require.NoError(t, err) + + a.height = 2 + c, err := a.Flush(ctx) + require.NoError(t, err) + after, err := LoadAMT(ctx, bs, c) + require.NoError(t, err) + + var out int + err = after.Get(ctx, 31, &out) + require.Error(t, err) +} diff --git a/link.go b/link.go new file mode 100644 index 0000000..fb08f2b --- /dev/null +++ b/link.go @@ -0,0 +1,32 @@ +package amt + +import ( + "context" + + "github.com/filecoin-project/go-amt-ipld/v3/internal" + cid "github.com/ipfs/go-cid" + cbor "github.com/ipfs/go-ipld-cbor" +) + +type link struct { + cid cid.Cid + + cached *node + dirty bool +} + +func (l *link) load(ctx context.Context, bs cbor.IpldStore, height int) (*node, error) { + if l.cached == nil { + var nd internal.Node + if err := bs.Get(ctx, l.cid, &nd); err != nil { + return nil, err + } + + n, err := newNode(nd, false, height == 0) + if err != nil { + return nil, err + } + l.cached = n + } + return l.cached, nil +} diff --git a/node.go b/node.go new file mode 100644 index 0000000..1fb018d --- /dev/null +++ b/node.go @@ -0,0 +1,328 @@ +package amt + +import ( + "bytes" + "context" + "errors" + "fmt" + + "github.com/filecoin-project/go-amt-ipld/v3/internal" + "github.com/ipfs/go-cid" + cbor "github.com/ipfs/go-ipld-cbor" + cbg "github.com/whyrusleeping/cbor-gen" +) + +type node struct { + // these may both be nil if the node is empty (a root node) + links [internal.Width]*link + values [internal.Width]*cbg.Deferred +} + +var ( + errEmptyNode = errors.New("unexpected empty amt node") + errUndefinedCID = errors.New("amt node has undefined CID") + errLinksAndValues = errors.New("amt node has both links and values") + errLeafUnexpected = errors.New("amt leaf not expected at height") + errLeafExpected = errors.New("amt expected at height") + errInvalidCount = errors.New("amt count does not match number of elements") +) + +func newNode(nd internal.Node, allowEmpty, expectLeaf bool) (*node, error) { + if len(nd.Links) > 0 && len(nd.Values) > 0 { + return nil, errLinksAndValues + } + + i := 0 + n := new(node) + if len(nd.Values) > 0 { + if !expectLeaf { + return nil, errLeafUnexpected + } + for x := uint(0); x < internal.Width; x++ { + if nd.Bmap[x/8]&(1<<(x%8)) > 0 { + if i >= len(nd.Values) { + return nil, fmt.Errorf("expected at least %d values, found %d", i+1, len(nd.Values)) + } + n.values[x] = nd.Values[i] + i++ + } + } + if i != len(nd.Values) { + return nil, fmt.Errorf("expected %d values, got %d", i, len(nd.Values)) + } + } else if len(nd.Links) > 0 { + if expectLeaf { + return nil, errLeafExpected + } + + for x := uint(0); x < internal.Width; x++ { + if nd.Bmap[x/8]&(1<<(x%8)) > 0 { + if i >= len(nd.Links) { + return nil, fmt.Errorf("expected at least %d links, found %d", i+1, len(nd.Links)) + } + c := nd.Links[i] + if !c.Defined() { + return nil, errUndefinedCID + } + // TODO: check link hash function. + prefix := c.Prefix() + if prefix.Codec != cid.DagCBOR { + return nil, fmt.Errorf("internal amt nodes must be cbor, found %d", prefix.Codec) + } + n.links[x] = &link{cid: c} + i++ + } + } + if i != len(nd.Links) { + return nil, fmt.Errorf("expected %d links, got %d", i, len(nd.Links)) + } + } else if !allowEmpty { + return nil, errEmptyNode + } + return n, nil +} + +func (nd *node) collapse(ctx context.Context, bs cbor.IpldStore, height int) (int, error) { + // If we have any links going "to the right", we can't collapse any + // more. + for _, l := range nd.links[1:] { + if l != nil { + return height, nil + } + } + + // If we have _no_ links, we've collapsed everything. + if nd.links[0] == nil { + return 0, nil + } + + // only one child, collapse it. + + subn, err := nd.links[0].load(ctx, bs, height-1) + if err != nil { + return 0, err + } + + // Collapse recursively. + newHeight, err := subn.collapse(ctx, bs, height-1) + if err != nil { + return 0, err + } + + *nd = *subn + + return newHeight, nil +} + +func (nd *node) empty() bool { + return nd.links == [len(nd.links)]*link{} && nd.values == [len(nd.links)]*cbg.Deferred{} +} + +func (n *node) get(ctx context.Context, bs cbor.IpldStore, height int, i uint64, out interface{}) (bool, error) { + if height == 0 { + d := n.values[i] + if d == nil { + return false, nil + } + if um, ok := out.(cbg.CBORUnmarshaler); ok { + return true, um.UnmarshalCBOR(bytes.NewReader(d.Raw)) + } + return true, cbor.DecodeInto(d.Raw, out) + } + nfh := nodesForHeight(height) + ln := n.links[i/nfh] + if ln == nil { + return false, nil + } + subn, err := ln.load(ctx, bs, height-1) + if err != nil { + return false, err + } + + return subn.get(ctx, bs, height-1, i%nfh, out) +} + +func (n *node) delete(ctx context.Context, bs cbor.IpldStore, height int, i uint64) (bool, error) { + if height == 0 { + if n.values[i] == nil { + return false, nil + } + + n.values[i] = nil + return true, nil + } + + nfh := nodesForHeight(height) + subi := i / nfh + + ln := n.links[subi] + if ln == nil { + return false, nil + } + subn, err := ln.load(ctx, bs, height-1) + if err != nil { + return false, err + } + + if deleted, err := subn.delete(ctx, bs, height-1, i%nfh); err != nil { + return false, err + } else if !deleted { + return false, nil + } + + if subn.empty() { + n.links[subi] = nil + } else { + ln.dirty = true + } + + return true, nil +} + +func (n *node) forEachAt(ctx context.Context, bs cbor.IpldStore, height int, start, offset uint64, cb func(uint64, *cbg.Deferred) error) error { + if height == 0 { + for i, v := range n.values { + if v != nil { + ix := offset + uint64(i) + if ix < start { + continue + } + + if err := cb(offset+uint64(i), v); err != nil { + return err + } + } + } + + return nil + } + + subCount := nodesForHeight(height) + for i, ln := range n.links { + if ln == nil { + continue + } + subn, err := ln.load(ctx, bs, height-1) + if err != nil { + return err + } + + offs := offset + (uint64(i) * subCount) + nextOffs := offs + subCount + if start >= nextOffs { + continue + } + + if err := subn.forEachAt(ctx, bs, height-1, start, offs, cb); err != nil { + return err + } + } + return nil + +} + +var errNoVals = fmt.Errorf("no values") + +func (n *node) firstSetIndex(ctx context.Context, bs cbor.IpldStore, height int) (uint64, error) { + if height == 0 { + for i, v := range n.values { + if v != nil { + return uint64(i), nil + } + } + // Empty array. + return 0, errNoVals + } + + for i, ln := range n.links { + if ln == nil { + // nothing here. + continue + } + subn, err := ln.load(ctx, bs, height-1) + if err != nil { + return 0, err + } + ix, err := subn.firstSetIndex(ctx, bs, height-1) + if err != nil { + return 0, err + } + + subCount := nodesForHeight(height) + return ix + (uint64(i) * subCount), nil + } + + return 0, errNoVals +} + +func (n *node) set(ctx context.Context, bs cbor.IpldStore, height int, i uint64, val *cbg.Deferred) (bool, error) { + if height == 0 { + alreadySet := n.values[i] != nil + n.values[i] = val + return !alreadySet, nil + } + + nfh := nodesForHeight(height) + + // Load but don't mark dirty or actually link in any _new_ intermediate + // nodes. We'll do that on return if nothing goes wrong. + ln := n.links[i/nfh] + if ln == nil { + ln = &link{cached: new(node)} + } + subn, err := ln.load(ctx, bs, height-1) + if err != nil { + return false, err + } + + nodeAdded, err := subn.set(ctx, bs, height-1, i%nfh, val) + if err != nil { + return false, err + } + + // Make all modifications on the way back up if there was no error. + ln.dirty = true // only mark dirty on success. + n.links[i/nfh] = ln + + return nodeAdded, nil +} + +func (n *node) flush(ctx context.Context, bs cbor.IpldStore, height int) (*internal.Node, error) { + var nd internal.Node + if height == 0 { + for i, val := range n.values { + if val == nil { + continue + } + nd.Values = append(nd.Values, val) + nd.Bmap[i/8] |= 1 << (uint(i) % 8) + } + return &nd, nil + } + + for i, ln := range n.links { + if ln == nil { + continue + } + if ln.dirty { + if ln.cached == nil { + return nil, fmt.Errorf("expected dirty node to be cached") + } + subn, err := ln.cached.flush(ctx, bs, height-1) + if err != nil { + return nil, err + } + cid, err := bs.Put(ctx, subn) + if err != nil { + return nil, err + } + + ln.cid = cid + ln.dirty = false + } + nd.Links = append(nd.Links, ln.cid) + nd.Bmap[i/8] |= 1 << (uint(i) % 8) + } + + return &nd, nil +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..0eded5b --- /dev/null +++ b/util.go @@ -0,0 +1,12 @@ +package amt + +import "github.com/filecoin-project/go-amt-ipld/v3/internal" + +func nodesForHeight(height int) uint64 { + heightLogTwo := uint64(internal.WidthBits * height) + if heightLogTwo >= 64 { + // Should never happen. Max height is checked at all entry points. + panic("height overflow") + } + return 1 << heightLogTwo +}