diff --git a/node.go b/node.go index aad39c7..d359346 100644 --- a/node.go +++ b/node.go @@ -464,7 +464,7 @@ func (n *node[V]) unionRec(o *node[V]) { nNode := n.children.get(oAddr) if nNode == nil { // union child from oNode into nNode - n.children.insert(oAddr, oNode) + n.children.insert(oAddr, oNode.cloneRec()) } else { // both nodes have child with addr, call union rec-descent nNode.unionRec(oNode) diff --git a/node_test.go b/node_test.go index 2eeb107..87e335b 100644 --- a/node_test.go +++ b/node_test.go @@ -13,6 +13,7 @@ import ( "bytes" "fmt" "math/rand" + "net/netip" "runtime" "sort" "testing" @@ -405,3 +406,38 @@ func getsEqual[V comparable](a V, aOK bool, b V, bOK bool) bool { } return a == b } + +// TestUnionMemoryAliasing tests that the Union method does not alias memory +// between the two tables. +func TestUnionMemoryAliasing(t *testing.T) { + newTable := func(pfx ...string) *Table[struct{}] { + var t Table[struct{}] + for _, s := range pfx { + t.Insert(netip.MustParsePrefix(s), struct{}{}) + } + return &t + } + // First create two tables with disjoint prefixes. + stable := newTable("0.0.0.0/24") + temp := newTable("100.69.1.0/24") + + // Verify that the tables are disjoint. + if stable.Overlaps(temp) { + t.Error("stable should not overlap temp") + } + + // Now union them. + temp.Union(stable) + + // Add a new prefix to temp. + temp.Insert(netip.MustParsePrefix("0.0.1.0/24"), struct{}{}) + + // Ensure that stable is unchanged. + _, ok := stable.Get(netip.MustParseAddr("0.0.1.1")) + if ok { + t.Error("stable should not contain 0.0.1.1") + } + if stable.OverlapsPrefix(netip.MustParsePrefix("0.0.1.1/32")) { + t.Error("stable should not overlap 0.0.1.1/32") + } +}