diff --git a/packages/ledger/branchdag/storage.go b/packages/ledger/branchdag/storage.go index a158ade856..c85463ebec 100644 --- a/packages/ledger/branchdag/storage.go +++ b/packages/ledger/branchdag/storage.go @@ -6,8 +6,6 @@ import ( "github.com/cockroachdb/errors" "github.com/iotaledger/hive.go/cerrors" "github.com/iotaledger/hive.go/generics/objectstorage" - "github.com/iotaledger/hive.go/generics/set" - "github.com/iotaledger/hive.go/generics/walker" "github.com/iotaledger/goshimmer/packages/database" ) @@ -109,61 +107,6 @@ func (s *Storage) ConflictMembers(conflictID ConflictID) (cachedConflictMembers return } -// ForEachConflictingBranchID executes the callback for each Branch that is conflicting with the Branch -// identified by the given BranchID. -func (s *Storage) ForEachConflictingBranchID(branchID BranchID, callback func(conflictingBranchID BranchID) bool) { - abort := false - s.Branch(branchID).Consume(func(branch *Branch) { - _ = branch.Conflicts().ForEach(func(conflictID ConflictID) (err error) { - s.ConflictMembers(conflictID).Consume(func(conflictMember *ConflictMember) { - if abort || conflictMember.BranchID() == branchID { - return - } - - abort = !callback(conflictMember.BranchID()) - }) - - if abort { - return errors.New("abort") - } - - return nil - }) - }) -} - -// ForEachConnectedConflictingBranchID executes the callback for each Branch that is connected through a chain -// of intersecting ConflictSets. -func (s *Storage) ForEachConnectedConflictingBranchID(branchID BranchID, callback func(conflictingBranchID BranchID)) { - traversedBranches := set.New[BranchID]() - conflictSetsWalker := walker.New[ConflictID]() - - processBranchAndQueueConflictSets := func(branchID BranchID) { - if !traversedBranches.Add(branchID) { - return - } - - s.Branch(branchID).Consume(func(branch *Branch) { - _ = branch.Conflicts().ForEach(func(conflictID ConflictID) (err error) { - conflictSetsWalker.Push(conflictID) - return nil - }) - }) - } - - processBranchAndQueueConflictSets(branchID) - - for conflictSetsWalker.HasNext() { - s.ConflictMembers(conflictSetsWalker.Next()).Consume(func(conflictMember *ConflictMember) { - processBranchAndQueueConflictSets(conflictMember.BranchID()) - }) - } - - traversedBranches.ForEach(func(element BranchID) { - callback(element) - }) -} - // Prune resets the database and deletes all objects (for testing or "node resets"). func (s *Storage) Prune() (err error) { for _, storagePrune := range []func() error{ diff --git a/packages/ledger/branchdag/types.go b/packages/ledger/branchdag/types.go index 2440e1e865..e86cde0ad3 100644 --- a/packages/ledger/branchdag/types.go +++ b/packages/ledger/branchdag/types.go @@ -15,8 +15,6 @@ import ( // region BranchID ///////////////////////////////////////////////////////////////////////////////////////////////////// -const BranchIDLength = types.IdentifierLength - type BranchID struct { types.Identifier } @@ -37,6 +35,8 @@ func (t BranchID) String() (humanReadable string) { var MasterBranchID BranchID +const BranchIDLength = types.IdentifierLength + // endregion /////////////////////////////////////////////////////////////////////////////////////////////////////////// // region BranchIDs //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/packages/ledger/branchdag/utils.go b/packages/ledger/branchdag/utils.go index 2435266a1f..0756acaa84 100644 --- a/packages/ledger/branchdag/utils.go +++ b/packages/ledger/branchdag/utils.go @@ -1,5 +1,11 @@ package branchdag +import ( + "github.com/cockroachdb/errors" + "github.com/iotaledger/hive.go/generics/set" + "github.com/iotaledger/hive.go/generics/walker" +) + type utils struct { branchDAG *BranchDAG } @@ -9,3 +15,58 @@ func newUtil(branchDAG *BranchDAG) (new *utils) { branchDAG: branchDAG, } } + +// ForEachConflictingBranchID executes the callback for each Branch that is conflicting with the Branch +// identified by the given BranchID. +func (u *utils) ForEachConflictingBranchID(branchID BranchID, callback func(conflictingBranchID BranchID) bool) { + abort := false + u.branchDAG.Storage.Branch(branchID).Consume(func(branch *Branch) { + _ = branch.Conflicts().ForEach(func(conflictID ConflictID) (err error) { + u.branchDAG.Storage.ConflictMembers(conflictID).Consume(func(conflictMember *ConflictMember) { + if abort || conflictMember.BranchID() == branchID { + return + } + + abort = !callback(conflictMember.BranchID()) + }) + + if abort { + return errors.New("abort") + } + + return nil + }) + }) +} + +// ForEachConnectedConflictingBranchID executes the callback for each Branch that is connected through a chain +// of intersecting ConflictSets. +func (u *utils) ForEachConnectedConflictingBranchID(branchID BranchID, callback func(conflictingBranchID BranchID)) { + traversedBranches := set.New[BranchID]() + conflictSetsWalker := walker.New[ConflictID]() + + processBranchAndQueueConflictSets := func(branchID BranchID) { + if !traversedBranches.Add(branchID) { + return + } + + u.branchDAG.Storage.Branch(branchID).Consume(func(branch *Branch) { + _ = branch.Conflicts().ForEach(func(conflictID ConflictID) (err error) { + conflictSetsWalker.Push(conflictID) + return nil + }) + }) + } + + processBranchAndQueueConflictSets(branchID) + + for conflictSetsWalker.HasNext() { + u.branchDAG.Storage.ConflictMembers(conflictSetsWalker.Next()).Consume(func(conflictMember *ConflictMember) { + processBranchAndQueueConflictSets(conflictMember.BranchID()) + }) + } + + traversedBranches.ForEach(func(element BranchID) { + callback(element) + }) +}