Skip to content

Commit d6c82c7

Browse files
scottconstablecuviper
authored andcommittedOct 14, 2020
[X86] Fix for ballooning compile times due to Load Value Injection (LVI) mitigations
Fix for the issue raised in rust-lang/rust#74632. The current heuristic for inserting LFENCEs uses a quadratic-time algorithm. This can apparently cause substantial compilation slowdowns for building Rust projects, where functions > 5000 LoC are apparently common. The updated heuristic in this patch implements a linear-time algorithm. On a set of benchmarks, the slowdown factor for the generated code was comparable (2.55x geo mean for the quadratic-time heuristic, vs. 2.58x for the linear-time heuristic). Both heuristics offer the same security properties, namely, mitigating LVI. This patch also includes some formatting fixes. Differential Revision: https://reviews.llvm.org/D84471
1 parent 03940cd commit d6c82c7

File tree

1 file changed

+87
-93
lines changed

1 file changed

+87
-93
lines changed
 

‎llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp

+87-93
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "X86TargetMachine.h"
4343
#include "llvm/ADT/DenseMap.h"
4444
#include "llvm/ADT/DenseSet.h"
45+
#include "llvm/ADT/STLExtras.h"
4546
#include "llvm/ADT/SmallSet.h"
4647
#include "llvm/ADT/Statistic.h"
4748
#include "llvm/ADT/StringRef.h"
@@ -104,9 +105,9 @@ static cl::opt<bool> EmitDotVerify(
104105
cl::init(false), cl::Hidden);
105106

106107
static llvm::sys::DynamicLibrary OptimizeDL;
107-
typedef int (*OptimizeCutT)(unsigned int *nodes, unsigned int nodes_size,
108-
unsigned int *edges, int *edge_values,
109-
int *cut_edges /* out */, unsigned int edges_size);
108+
typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize,
109+
unsigned int *Edges, int *EdgeValues,
110+
int *CutEdges /* out */, unsigned int EdgesSize);
110111
static OptimizeCutT OptimizeCut = nullptr;
111112

112113
namespace {
@@ -148,9 +149,10 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
148149

149150
private:
150151
using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
152+
using Edge = MachineGadgetGraph::Edge;
153+
using Node = MachineGadgetGraph::Node;
151154
using EdgeSet = MachineGadgetGraph::EdgeSet;
152155
using NodeSet = MachineGadgetGraph::NodeSet;
153-
using Gadget = std::pair<MachineInstr *, MachineInstr *>;
154156

155157
const X86Subtarget *STI;
156158
const TargetInstrInfo *TII;
@@ -162,8 +164,8 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
162164
const MachineDominanceFrontier &MDF) const;
163165
int hardenLoadsWithPlugin(MachineFunction &MF,
164166
std::unique_ptr<MachineGadgetGraph> Graph) const;
165-
int hardenLoadsWithGreedyHeuristic(
166-
MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const;
167+
int hardenLoadsWithHeuristic(MachineFunction &MF,
168+
std::unique_ptr<MachineGadgetGraph> Graph) const;
167169
int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G,
168170
EdgeSet &ElimEdges /* in, out */,
169171
NodeSet &ElimNodes /* in, out */) const;
@@ -198,7 +200,7 @@ struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
198200
using ChildIteratorType = typename Traits::ChildIteratorType;
199201
using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;
200202

201-
DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
203+
DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
202204

203205
std::string getNodeLabel(NodeRef Node, GraphType *) {
204206
if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel)
@@ -243,7 +245,7 @@ void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
243245
AU.setPreservesCFG();
244246
}
245247

246-
static void WriteGadgetGraph(raw_ostream &OS, MachineFunction &MF,
248+
static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF,
247249
MachineGadgetGraph *G) {
248250
WriteGraph(OS, G, /*ShortNames*/ false,
249251
"Speculative gadgets for \"" + MF.getName() + "\" function");
@@ -279,7 +281,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
279281
return false; // didn't find any gadgets
280282

281283
if (EmitDotVerify) {
282-
WriteGadgetGraph(outs(), MF, Graph.get());
284+
writeGadgetGraph(outs(), MF, Graph.get());
283285
return false;
284286
}
285287

@@ -292,7 +294,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
292294
raw_fd_ostream FileOut(FileName, FileError);
293295
if (FileError)
294296
errs() << FileError.message();
295-
WriteGadgetGraph(FileOut, MF, Graph.get());
297+
writeGadgetGraph(FileOut, MF, Graph.get());
296298
FileOut.close();
297299
LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n");
298300
if (EmitDotOnly)
@@ -313,7 +315,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
313315
}
314316
FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph));
315317
} else { // Use the default greedy heuristic
316-
FencesInserted = hardenLoadsWithGreedyHeuristic(MF, std::move(Graph));
318+
FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph));
317319
}
318320

319321
if (FencesInserted > 0)
@@ -540,47 +542,46 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
540542

541543
// Returns the number of remaining gadget edges that could not be eliminated
542544
int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
543-
MachineGadgetGraph &G, MachineGadgetGraph::EdgeSet &ElimEdges /* in, out */,
544-
MachineGadgetGraph::NodeSet &ElimNodes /* in, out */) const {
545+
MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */,
546+
NodeSet &ElimNodes /* in, out */) const {
545547
if (G.NumFences > 0) {
546548
// Eliminate fences and CFG edges that ingress and egress the fence, as
547549
// they are trivially mitigated.
548-
for (const auto &E : G.edges()) {
549-
const MachineGadgetGraph::Node *Dest = E.getDest();
550+
for (const Edge &E : G.edges()) {
551+
const Node *Dest = E.getDest();
550552
if (isFence(Dest->getValue())) {
551553
ElimNodes.insert(*Dest);
552554
ElimEdges.insert(E);
553-
for (const auto &DE : Dest->edges())
555+
for (const Edge &DE : Dest->edges())
554556
ElimEdges.insert(DE);
555557
}
556558
}
557559
}
558560

559561
// Find and eliminate gadget edges that have been mitigated.
560562
int MitigatedGadgets = 0, RemainingGadgets = 0;
561-
MachineGadgetGraph::NodeSet ReachableNodes{G};
562-
for (const auto &RootN : G.nodes()) {
563+
NodeSet ReachableNodes{G};
564+
for (const Node &RootN : G.nodes()) {
563565
if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge))
564566
continue; // skip this node if it isn't a gadget source
565567

566568
// Find all of the nodes that are CFG-reachable from RootN using DFS
567569
ReachableNodes.clear();
568-
std::function<void(const MachineGadgetGraph::Node *, bool)>
569-
FindReachableNodes =
570-
[&](const MachineGadgetGraph::Node *N, bool FirstNode) {
571-
if (!FirstNode)
572-
ReachableNodes.insert(*N);
573-
for (const auto &E : N->edges()) {
574-
const MachineGadgetGraph::Node *Dest = E.getDest();
575-
if (MachineGadgetGraph::isCFGEdge(E) &&
576-
!ElimEdges.contains(E) && !ReachableNodes.contains(*Dest))
577-
FindReachableNodes(Dest, false);
578-
}
579-
};
570+
std::function<void(const Node *, bool)> FindReachableNodes =
571+
[&](const Node *N, bool FirstNode) {
572+
if (!FirstNode)
573+
ReachableNodes.insert(*N);
574+
for (const Edge &E : N->edges()) {
575+
const Node *Dest = E.getDest();
576+
if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) &&
577+
!ReachableNodes.contains(*Dest))
578+
FindReachableNodes(Dest, false);
579+
}
580+
};
580581
FindReachableNodes(&RootN, true);
581582

582583
// Any gadget whose sink is unreachable has been mitigated
583-
for (const auto &E : RootN.edges()) {
584+
for (const Edge &E : RootN.edges()) {
584585
if (MachineGadgetGraph::isGadgetEdge(E)) {
585586
if (ReachableNodes.contains(*E.getDest())) {
586587
// This gadget's sink is reachable
@@ -598,8 +599,8 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
598599
std::unique_ptr<MachineGadgetGraph>
599600
X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges(
600601
std::unique_ptr<MachineGadgetGraph> Graph) const {
601-
MachineGadgetGraph::NodeSet ElimNodes{*Graph};
602-
MachineGadgetGraph::EdgeSet ElimEdges{*Graph};
602+
NodeSet ElimNodes{*Graph};
603+
EdgeSet ElimEdges{*Graph};
603604
int RemainingGadgets =
604605
elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes);
605606
if (ElimEdges.empty() && ElimNodes.empty()) {
@@ -630,11 +631,11 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
630631
auto Edges = std::make_unique<unsigned int[]>(Graph->edges_size());
631632
auto EdgeCuts = std::make_unique<int[]>(Graph->edges_size());
632633
auto EdgeValues = std::make_unique<int[]>(Graph->edges_size());
633-
for (const auto &N : Graph->nodes()) {
634+
for (const Node &N : Graph->nodes()) {
634635
Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin());
635636
}
636637
Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node
637-
for (const auto &E : Graph->edges()) {
638+
for (const Edge &E : Graph->edges()) {
638639
Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest());
639640
EdgeValues[Graph->getEdgeIndex(E)] = E.getValue();
640641
}
@@ -651,74 +652,67 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
651652
LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
652653
LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n");
653654

654-
Graph = GraphBuilder::trim(*Graph, MachineGadgetGraph::NodeSet{*Graph},
655-
CutEdges);
655+
Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges);
656656
} while (true);
657657

658658
return FencesInserted;
659659
}
660660

661-
int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic(
661+
int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic(
662662
MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
663-
LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
664-
Graph = trimMitigatedEdges(std::move(Graph));
665-
LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
663+
// If `MF` does not have any fences, then no gadgets would have been
664+
// mitigated at this point.
665+
if (Graph->NumFences > 0) {
666+
LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
667+
Graph = trimMitigatedEdges(std::move(Graph));
668+
LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
669+
}
670+
666671
if (Graph->NumGadgets == 0)
667672
return 0;
668673

669674
LLVM_DEBUG(dbgs() << "Cutting edges...\n");
670-
MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph};
671-
MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph};
672-
auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) {
673-
return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
674-
MachineGadgetGraph::isCFGEdge(E);
675-
};
676-
auto IsGadgetEdge = [&ElimEdges,
677-
&CutEdges](const MachineGadgetGraph::Edge &E) {
678-
return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
679-
MachineGadgetGraph::isGadgetEdge(E);
680-
};
681-
682-
// FIXME: this is O(E^2), we could probably do better.
683-
do {
684-
// Find the cheapest CFG edge that will eliminate a gadget (by being
685-
// egress from a SOURCE node or ingress to a SINK node), and cut it.
686-
const MachineGadgetGraph::Edge *CheapestSoFar = nullptr;
687-
688-
// First, collect all gadget source and sink nodes.
689-
MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph};
690-
for (const auto &N : Graph->nodes()) {
691-
if (ElimNodes.contains(N))
675+
EdgeSet CutEdges{*Graph};
676+
677+
// Begin by collecting all ingress CFG edges for each node
678+
DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap;
679+
for (const Edge &E : Graph->edges())
680+
if (MachineGadgetGraph::isCFGEdge(E))
681+
IngressEdgeMap[E.getDest()].push_back(&E);
682+
683+
// For each gadget edge, make cuts that guarantee the gadget will be
684+
// mitigated. A computationally efficient way to achieve this is to either:
685+
// (a) cut all egress CFG edges from the gadget source, or
686+
// (b) cut all ingress CFG edges to the gadget sink.
687+
//
688+
// Moreover, the algorithm tries not to make a cut into a loop by preferring
689+
// to make a (b)-type cut if the gadget source resides at a greater loop depth
690+
// than the gadget sink, or an (a)-type cut otherwise.
691+
for (const Node &N : Graph->nodes()) {
692+
for (const Edge &E : N.edges()) {
693+
if (!MachineGadgetGraph::isGadgetEdge(E))
692694
continue;
693-
for (const auto &E : N.edges()) {
694-
if (IsGadgetEdge(E)) {
695-
GadgetSources.insert(N);
696-
GadgetSinks.insert(*E.getDest());
697-
}
698-
}
699-
}
700695

701-
// Next, look for the cheapest CFG edge which, when cut, is guaranteed to
702-
// mitigate at least one gadget by either:
703-
// (a) being egress from a gadget source, or
704-
// (b) being ingress to a gadget sink.
705-
for (const auto &N : Graph->nodes()) {
706-
if (ElimNodes.contains(N))
707-
continue;
708-
for (const auto &E : N.edges()) {
709-
if (IsCFGEdge(E)) {
710-
if (GadgetSources.contains(N) || GadgetSinks.contains(*E.getDest())) {
711-
if (!CheapestSoFar || E.getValue() < CheapestSoFar->getValue())
712-
CheapestSoFar = &E;
713-
}
714-
}
715-
}
696+
SmallVector<const Edge *, 2> EgressEdges;
697+
SmallVector<const Edge *, 2> &IngressEdges = IngressEdgeMap[E.getDest()];
698+
for (const Edge &EgressEdge : N.edges())
699+
if (MachineGadgetGraph::isCFGEdge(EgressEdge))
700+
EgressEdges.push_back(&EgressEdge);
701+
702+
int EgressCutCost = 0, IngressCutCost = 0;
703+
for (const Edge *EgressEdge : EgressEdges)
704+
if (!CutEdges.contains(*EgressEdge))
705+
EgressCutCost += EgressEdge->getValue();
706+
for (const Edge *IngressEdge : IngressEdges)
707+
if (!CutEdges.contains(*IngressEdge))
708+
IngressCutCost += IngressEdge->getValue();
709+
710+
auto &EdgesToCut =
711+
IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges;
712+
for (const Edge *E : EdgesToCut)
713+
CutEdges.insert(*E);
716714
}
717-
718-
assert(CheapestSoFar && "Failed to cut an edge");
719-
CutEdges.insert(*CheapestSoFar);
720-
ElimEdges.insert(*CheapestSoFar);
721-
} while (elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes));
715+
}
722716
LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
723717
LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");
724718

@@ -734,8 +728,8 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
734728
MachineFunction &MF, MachineGadgetGraph &G,
735729
EdgeSet &CutEdges /* in, out */) const {
736730
int FencesInserted = 0;
737-
for (const auto &N : G.nodes()) {
738-
for (const auto &E : N.edges()) {
731+
for (const Node &N : G.nodes()) {
732+
for (const Edge &E : N.edges()) {
739733
if (CutEdges.contains(E)) {
740734
MachineInstr *MI = N.getValue(), *Prev;
741735
MachineBasicBlock *MBB; // Insert an LFENCE in this MBB
@@ -751,7 +745,7 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
751745
Prev = MI->getPrevNode();
752746
// Remove all egress CFG edges from this branch because the inserted
753747
// LFENCE prevents gadgets from crossing the branch.
754-
for (const auto &E : N.edges()) {
748+
for (const Edge &E : N.edges()) {
755749
if (MachineGadgetGraph::isCFGEdge(E))
756750
CutEdges.insert(E);
757751
}

0 commit comments

Comments
 (0)