Skip to content

Commit

Permalink
Address comments from Ellis
Browse files Browse the repository at this point in the history
  • Loading branch information
kyulee-com committed May 5, 2024
1 parent 70c3d08 commit c05bdf6
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 42 deletions.
29 changes: 10 additions & 19 deletions llvm/include/llvm/CodeGenData/OutlinedHashTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,22 @@ namespace llvm {
/// a hash sequence with that occurrence count.
struct HashNode {
/// The hash value of the node.
stable_hash Hash;
stable_hash Hash = 0;
/// The number of terminals in the sequence ending at this node.
unsigned Terminals;
std::optional<unsigned> Terminals;
/// The successors of this node.
/// We don't use DenseMap as a stable_hash value can be tombstone.
std::unordered_map<stable_hash, std::unique_ptr<HashNode>> Successors;
};

/// HashNodeStable is the serialized, stable, and compact representation
/// of a HashNode.
struct HashNodeStable {
llvm::yaml::Hex64 Hash;
unsigned Terminals;
std::vector<unsigned> SuccessorIds;
};

class OutlinedHashTree {

using EdgeCallbackFn =
std::function<void(const HashNode *, const HashNode *)>;
using NodeCallbackFn = std::function<void(const HashNode *)>;

using HashSequence = std::vector<stable_hash>;
using HashSequencePair = std::pair<std::vector<stable_hash>, unsigned>;
using HashSequence = SmallVector<stable_hash>;
using HashSequencePair = std::pair<HashSequence, unsigned>;

public:
/// Walks every edge and node in the OutlinedHashTree and calls CallbackEdge
Expand All @@ -66,7 +59,7 @@ class OutlinedHashTree {

/// Release all hash nodes except the root hash node.
void clear() {
assert(getRoot()->Hash == 0 && getRoot()->Terminals == 0);
assert(getRoot()->Hash == 0 && !getRoot()->Terminals);
getRoot()->Successors.clear();
}

Expand All @@ -83,8 +76,8 @@ class OutlinedHashTree {
size_t depth() const;

/// \returns the root hash node of a OutlinedHashTree.
const HashNode *getRoot() const { return Root.get(); }
HashNode *getRoot() { return Root.get(); }
const HashNode *getRoot() const { return &Root; }
HashNode *getRoot() { return &Root; }

/// Inserts a \p Sequence into the this tree. The last node in the sequence
/// will increase Terminals.
Expand All @@ -94,12 +87,10 @@ class OutlinedHashTree {
void merge(const OutlinedHashTree *OtherTree);

/// \returns the matching count if \p Sequence exists in the OutlinedHashTree.
unsigned find(const HashSequence &Sequence) const;

OutlinedHashTree() { Root = std::make_unique<HashNode>(); }
std::optional<unsigned> find(const HashSequence &Sequence) const;

private:
std::unique_ptr<HashNode> Root;
HashNode Root;
};

} // namespace llvm
Expand Down
13 changes: 11 additions & 2 deletions llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,22 @@
#ifndef LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
#define LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGenData/OutlinedHashTree.h"

namespace llvm {

/// HashNodeStable is the serialized, stable, and compact representation
/// of a HashNode.
struct HashNodeStable {
llvm::yaml::Hex64 Hash;
unsigned Terminals;
std::vector<unsigned> SuccessorIds;
};

using IdHashNodeStableMapTy = std::map<unsigned, HashNodeStable>;
using IdHashNodeMapTy = std::map<unsigned, HashNode *>;
using HashNodeIdMapTy = std::unordered_map<const HashNode *, unsigned>;
using IdHashNodeMapTy = DenseMap<unsigned, HashNode *>;
using HashNodeIdMapTy = DenseMap<const HashNode *, unsigned>;

struct OutlinedHashTreeRecord {
std::unique_ptr<OutlinedHashTree> HashTree;
Expand Down
32 changes: 16 additions & 16 deletions llvm/lib/CodeGenData/OutlinedHashTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@ using namespace llvm;
void OutlinedHashTree::walkGraph(NodeCallbackFn CallbackNode,
EdgeCallbackFn CallbackEdge,
bool SortedWalk) const {
std::stack<const HashNode *> Stack;
Stack.push(getRoot());
SmallVector<const HashNode *> Stack;
Stack.emplace_back(getRoot());

while (!Stack.empty()) {
const auto *Current = Stack.top();
Stack.pop();
const auto *Current = Stack.pop_back_val();
if (CallbackNode)
CallbackNode(Current);

auto HandleNext = [&](const HashNode *Next) {
if (CallbackEdge)
CallbackEdge(Current, Next);
Stack.push(Next);
Stack.emplace_back(Next);
};
if (SortedWalk) {
std::map<stable_hash, const HashNode *> SortedSuccessors;
Expand Down Expand Up @@ -72,8 +71,7 @@ size_t OutlinedHashTree::depth() const {
}

void OutlinedHashTree::insert(const HashSequencePair &SequencePair) {
const auto &Sequence = SequencePair.first;
unsigned Count = SequencePair.second;
auto &[Sequence, Count] = SequencePair;
HashNode *Current = getRoot();

for (stable_hash StableHash : Sequence) {
Expand All @@ -87,22 +85,23 @@ void OutlinedHashTree::insert(const HashSequencePair &SequencePair) {
} else
Current = I->second.get();
}
Current->Terminals += Count;
if (Count)
Current->Terminals = (Current->Terminals ? *Current->Terminals : 0) + Count;
}

void OutlinedHashTree::merge(const OutlinedHashTree *Tree) {
HashNode *Dst = getRoot();
const HashNode *Src = Tree->getRoot();
std::stack<std::pair<HashNode *, const HashNode *>> Stack;
Stack.push({Dst, Src});
SmallVector<std::pair<HashNode *, const HashNode *>> Stack;
Stack.emplace_back(Dst, Src);

while (!Stack.empty()) {
auto [DstNode, SrcNode] = Stack.top();
Stack.pop();
auto [DstNode, SrcNode] = Stack.pop_back_val();
if (!SrcNode)
continue;
DstNode->Terminals += SrcNode->Terminals;

if (SrcNode->Terminals)
DstNode->Terminals =
(DstNode->Terminals ? *DstNode->Terminals : 0) + *SrcNode->Terminals;
for (auto &[Hash, NextSrcNode] : SrcNode->Successors) {
HashNode *NextDstNode;
auto I = DstNode->Successors.find(Hash);
Expand All @@ -114,12 +113,13 @@ void OutlinedHashTree::merge(const OutlinedHashTree *Tree) {
} else
NextDstNode = I->second.get();

Stack.push({NextDstNode, NextSrcNode.get()});
Stack.emplace_back(NextDstNode, NextSrcNode.get());
}
}
}

unsigned OutlinedHashTree::find(const HashSequence &Sequence) const {
std::optional<unsigned>
OutlinedHashTree::find(const HashSequence &Sequence) const {
const HashNode *Current = getRoot();
for (stable_hash StableHash : Sequence) {
const auto I = Current->Successors.find(StableHash);
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ void OutlinedHashTreeRecord::convertToStableData(
auto Id = P.second;
HashNodeStable NodeStable;
NodeStable.Hash = Node->Hash;
NodeStable.Terminals = Node->Terminals;
NodeStable.Terminals = Node->Terminals ? *Node->Terminals : 0;
for (auto &P : Node->Successors)
NodeStable.SuccessorIds.push_back(NodeIdMap[P.second.get()]);
IdNodeStableMap[Id] = NodeStable;
}

// Sort the Successors so that they come out in the same order as in the map.
for (auto &P : IdNodeStableMap)
std::sort(P.second.SuccessorIds.begin(), P.second.SuccessorIds.end());
llvm::sort(P.second.SuccessorIds);
}

void OutlinedHashTreeRecord::convertFromStableData(
Expand All @@ -155,7 +155,8 @@ void OutlinedHashTreeRecord::convertFromStableData(
assert(IdNodeMap.count(Id));
HashNode *Curr = IdNodeMap[Id];
Curr->Hash = NodeStable.Hash;
Curr->Terminals = NodeStable.Terminals;
if (NodeStable.Terminals)
Curr->Terminals = NodeStable.Terminals;
auto &Successors = Curr->Successors;
assert(Successors.empty());
for (auto SuccessorId : NodeStable.SuccessorIds) {
Expand Down
5 changes: 3 additions & 2 deletions llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ TEST(OutlinedHashTreeTest, Find) {
// The node count does not change as the same sequences are added.
ASSERT_TRUE(HashTree.size() == 4);
// The terminal counts are accumulated from two same sequences.
ASSERT_TRUE(HashTree.find({1, 2, 3}) == 3);
ASSERT_TRUE(HashTree.find({1, 2}) == 0);
ASSERT_TRUE(HashTree.find({1, 2, 3}));
ASSERT_TRUE(HashTree.find({1, 2, 3}).value() == 3);
ASSERT_FALSE(HashTree.find({1, 2}));
}

TEST(OutlinedHashTreeTest, Merge) {
Expand Down

0 comments on commit c05bdf6

Please sign in to comment.