Skip to content

Commit

Permalink
bart: do not alias memory in Union
Browse files Browse the repository at this point in the history
Union was inadvertently reusing a `*node` from the
other table which meant that when either table was
later mutated the change would appear in both tables.

Add a test and fix the condition.

Updates #16
  • Loading branch information
Maisem Ali committed Mar 26, 2024
1 parent de0f3f6 commit 8879d33
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion node.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func (n *node[V]) unionRec(o *node[V]) {
nNode := n.children.get(oAddr)
if nNode == nil {
// union child from oNode into nNode
n.children.insert(oAddr, oNode)
n.children.insert(oAddr, oNode.cloneRec())
} else {
// both nodes have child with addr, call union rec-descent
nNode.unionRec(oNode)
Expand Down
36 changes: 36 additions & 0 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"bytes"
"fmt"
"math/rand"
"net/netip"
"runtime"
"sort"
"testing"
Expand Down Expand Up @@ -405,3 +406,38 @@ func getsEqual[V comparable](a V, aOK bool, b V, bOK bool) bool {
}
return a == b
}

// TestUnionMemoryAliasing tests that the Union method does not alias memory
// between the two tables.
func TestUnionMemoryAliasing(t *testing.T) {
newTable := func(pfx ...string) *Table[struct{}] {
var t Table[struct{}]
for _, s := range pfx {
t.Insert(netip.MustParsePrefix(s), struct{}{})
}
return &t
}
// First create two tables with disjoint prefixes.
stable := newTable("0.0.0.0/24")
temp := newTable("100.69.1.0/24")

// Verify that the tables are disjoint.
if stable.Overlaps(temp) {
t.Error("stable should not overlap temp")
}

// Now union them.
temp.Union(stable)

// Add a new prefix to temp.
temp.Insert(netip.MustParsePrefix("0.0.1.0/24"), struct{}{})

// Ensure that stable is unchanged.
_, ok := stable.Get(netip.MustParseAddr("0.0.1.1"))
if ok {
t.Error("stable should not contain 0.0.1.1")
}
if stable.OverlapsPrefix(netip.MustParsePrefix("0.0.1.1/32")) {
t.Error("stable should not overlap 0.0.1.1/32")
}
}

0 comments on commit 8879d33

Please sign in to comment.