diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTree.h b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h index 875e1a78bb401..c40038cd8c517 100644 --- a/llvm/include/llvm/CodeGenData/OutlinedHashTree.h +++ b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h @@ -30,29 +30,22 @@ namespace llvm { /// a hash sequence with that occurrence count. struct HashNode { /// The hash value of the node. - stable_hash Hash; + stable_hash Hash = 0; /// The number of terminals in the sequence ending at this node. - unsigned Terminals; + std::optional Terminals; /// The successors of this node. + /// We don't use DenseMap as a stable_hash value can be tombstone. std::unordered_map> Successors; }; -/// HashNodeStable is the serialized, stable, and compact representation -/// of a HashNode. -struct HashNodeStable { - llvm::yaml::Hex64 Hash; - unsigned Terminals; - std::vector SuccessorIds; -}; - class OutlinedHashTree { using EdgeCallbackFn = std::function; using NodeCallbackFn = std::function; - using HashSequence = std::vector; - using HashSequencePair = std::pair, unsigned>; + using HashSequence = SmallVector; + using HashSequencePair = std::pair; public: /// Walks every edge and node in the OutlinedHashTree and calls CallbackEdge @@ -66,7 +59,7 @@ class OutlinedHashTree { /// Release all hash nodes except the root hash node. void clear() { - assert(getRoot()->Hash == 0 && getRoot()->Terminals == 0); + assert(getRoot()->Hash == 0 && !getRoot()->Terminals); getRoot()->Successors.clear(); } @@ -83,8 +76,8 @@ class OutlinedHashTree { size_t depth() const; /// \returns the root hash node of a OutlinedHashTree. - const HashNode *getRoot() const { return Root.get(); } - HashNode *getRoot() { return Root.get(); } + const HashNode *getRoot() const { return &Root; } + HashNode *getRoot() { return &Root; } /// Inserts a \p Sequence into the this tree. The last node in the sequence /// will increase Terminals. @@ -94,12 +87,10 @@ class OutlinedHashTree { void merge(const OutlinedHashTree *OtherTree); /// \returns the matching count if \p Sequence exists in the OutlinedHashTree. - unsigned find(const HashSequence &Sequence) const; - - OutlinedHashTree() { Root = std::make_unique(); } + std::optional find(const HashSequence &Sequence) const; private: - std::unique_ptr Root; + HashNode Root; }; } // namespace llvm diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h index ccd2ad26dd087..2960e31960448 100644 --- a/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h +++ b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h @@ -16,13 +16,22 @@ #ifndef LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H #define LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H +#include "llvm/ADT/DenseMap.h" #include "llvm/CodeGenData/OutlinedHashTree.h" namespace llvm { +/// HashNodeStable is the serialized, stable, and compact representation +/// of a HashNode. +struct HashNodeStable { + llvm::yaml::Hex64 Hash; + unsigned Terminals; + std::vector SuccessorIds; +}; + using IdHashNodeStableMapTy = std::map; -using IdHashNodeMapTy = std::map; -using HashNodeIdMapTy = std::unordered_map; +using IdHashNodeMapTy = DenseMap; +using HashNodeIdMapTy = DenseMap; struct OutlinedHashTreeRecord { std::unique_ptr HashTree; diff --git a/llvm/lib/CodeGenData/OutlinedHashTree.cpp b/llvm/lib/CodeGenData/OutlinedHashTree.cpp index 032993ded60ea..cb985aa87afcf 100644 --- a/llvm/lib/CodeGenData/OutlinedHashTree.cpp +++ b/llvm/lib/CodeGenData/OutlinedHashTree.cpp @@ -24,19 +24,18 @@ using namespace llvm; void OutlinedHashTree::walkGraph(NodeCallbackFn CallbackNode, EdgeCallbackFn CallbackEdge, bool SortedWalk) const { - std::stack Stack; - Stack.push(getRoot()); + SmallVector Stack; + Stack.emplace_back(getRoot()); while (!Stack.empty()) { - const auto *Current = Stack.top(); - Stack.pop(); + const auto *Current = Stack.pop_back_val(); if (CallbackNode) CallbackNode(Current); auto HandleNext = [&](const HashNode *Next) { if (CallbackEdge) CallbackEdge(Current, Next); - Stack.push(Next); + Stack.emplace_back(Next); }; if (SortedWalk) { std::map SortedSuccessors; @@ -72,8 +71,7 @@ size_t OutlinedHashTree::depth() const { } void OutlinedHashTree::insert(const HashSequencePair &SequencePair) { - const auto &Sequence = SequencePair.first; - unsigned Count = SequencePair.second; + auto &[Sequence, Count] = SequencePair; HashNode *Current = getRoot(); for (stable_hash StableHash : Sequence) { @@ -87,22 +85,23 @@ void OutlinedHashTree::insert(const HashSequencePair &SequencePair) { } else Current = I->second.get(); } - Current->Terminals += Count; + if (Count) + Current->Terminals = (Current->Terminals ? *Current->Terminals : 0) + Count; } void OutlinedHashTree::merge(const OutlinedHashTree *Tree) { HashNode *Dst = getRoot(); const HashNode *Src = Tree->getRoot(); - std::stack> Stack; - Stack.push({Dst, Src}); + SmallVector> Stack; + Stack.emplace_back(Dst, Src); while (!Stack.empty()) { - auto [DstNode, SrcNode] = Stack.top(); - Stack.pop(); + auto [DstNode, SrcNode] = Stack.pop_back_val(); if (!SrcNode) continue; - DstNode->Terminals += SrcNode->Terminals; - + if (SrcNode->Terminals) + DstNode->Terminals = + (DstNode->Terminals ? *DstNode->Terminals : 0) + *SrcNode->Terminals; for (auto &[Hash, NextSrcNode] : SrcNode->Successors) { HashNode *NextDstNode; auto I = DstNode->Successors.find(Hash); @@ -114,12 +113,13 @@ void OutlinedHashTree::merge(const OutlinedHashTree *Tree) { } else NextDstNode = I->second.get(); - Stack.push({NextDstNode, NextSrcNode.get()}); + Stack.emplace_back(NextDstNode, NextSrcNode.get()); } } } -unsigned OutlinedHashTree::find(const HashSequence &Sequence) const { +std::optional +OutlinedHashTree::find(const HashSequence &Sequence) const { const HashNode *Current = getRoot(); for (stable_hash StableHash : Sequence) { const auto I = Current->Successors.find(StableHash); diff --git a/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp index 0d5dd864c89c5..da4db7e9e69f1 100644 --- a/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp +++ b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp @@ -131,7 +131,7 @@ void OutlinedHashTreeRecord::convertToStableData( auto Id = P.second; HashNodeStable NodeStable; NodeStable.Hash = Node->Hash; - NodeStable.Terminals = Node->Terminals; + NodeStable.Terminals = Node->Terminals ? *Node->Terminals : 0; for (auto &P : Node->Successors) NodeStable.SuccessorIds.push_back(NodeIdMap[P.second.get()]); IdNodeStableMap[Id] = NodeStable; @@ -139,7 +139,7 @@ void OutlinedHashTreeRecord::convertToStableData( // Sort the Successors so that they come out in the same order as in the map. for (auto &P : IdNodeStableMap) - std::sort(P.second.SuccessorIds.begin(), P.second.SuccessorIds.end()); + llvm::sort(P.second.SuccessorIds); } void OutlinedHashTreeRecord::convertFromStableData( @@ -155,7 +155,8 @@ void OutlinedHashTreeRecord::convertFromStableData( assert(IdNodeMap.count(Id)); HashNode *Curr = IdNodeMap[Id]; Curr->Hash = NodeStable.Hash; - Curr->Terminals = NodeStable.Terminals; + if (NodeStable.Terminals) + Curr->Terminals = NodeStable.Terminals; auto &Successors = Curr->Successors; assert(Successors.empty()); for (auto SuccessorId : NodeStable.SuccessorIds) { diff --git a/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp index d11618cf8e4fa..5fdfa60673b7f 100644 --- a/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp +++ b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp @@ -50,8 +50,9 @@ TEST(OutlinedHashTreeTest, Find) { // The node count does not change as the same sequences are added. ASSERT_TRUE(HashTree.size() == 4); // The terminal counts are accumulated from two same sequences. - ASSERT_TRUE(HashTree.find({1, 2, 3}) == 3); - ASSERT_TRUE(HashTree.find({1, 2}) == 0); + ASSERT_TRUE(HashTree.find({1, 2, 3})); + ASSERT_TRUE(HashTree.find({1, 2, 3}).value() == 3); + ASSERT_FALSE(HashTree.find({1, 2})); } TEST(OutlinedHashTreeTest, Merge) {