42
42
#include " X86TargetMachine.h"
43
43
#include " llvm/ADT/DenseMap.h"
44
44
#include " llvm/ADT/DenseSet.h"
45
+ #include " llvm/ADT/STLExtras.h"
45
46
#include " llvm/ADT/SmallSet.h"
46
47
#include " llvm/ADT/Statistic.h"
47
48
#include " llvm/ADT/StringRef.h"
@@ -104,9 +105,9 @@ static cl::opt<bool> EmitDotVerify(
104
105
cl::init(false ), cl::Hidden);
105
106
106
107
static llvm::sys::DynamicLibrary OptimizeDL;
107
- typedef int (*OptimizeCutT)(unsigned int *nodes , unsigned int nodes_size ,
108
- unsigned int *edges , int *edge_values ,
109
- int *cut_edges /* out */ , unsigned int edges_size );
108
+ typedef int (*OptimizeCutT)(unsigned int *Nodes , unsigned int NodesSize ,
109
+ unsigned int *Edges , int *EdgeValues ,
110
+ int *CutEdges /* out */ , unsigned int EdgesSize );
110
111
static OptimizeCutT OptimizeCut = nullptr ;
111
112
112
113
namespace {
@@ -148,9 +149,10 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
148
149
149
150
private:
150
151
using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
152
+ using Edge = MachineGadgetGraph::Edge;
153
+ using Node = MachineGadgetGraph::Node;
151
154
using EdgeSet = MachineGadgetGraph::EdgeSet;
152
155
using NodeSet = MachineGadgetGraph::NodeSet;
153
- using Gadget = std::pair<MachineInstr *, MachineInstr *>;
154
156
155
157
const X86Subtarget *STI;
156
158
const TargetInstrInfo *TII;
@@ -162,8 +164,8 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
162
164
const MachineDominanceFrontier &MDF) const ;
163
165
int hardenLoadsWithPlugin (MachineFunction &MF,
164
166
std::unique_ptr<MachineGadgetGraph> Graph) const ;
165
- int hardenLoadsWithGreedyHeuristic (
166
- MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const ;
167
+ int hardenLoadsWithHeuristic (MachineFunction &MF,
168
+ std::unique_ptr<MachineGadgetGraph> Graph) const ;
167
169
int elimMitigatedEdgesAndNodes (MachineGadgetGraph &G,
168
170
EdgeSet &ElimEdges /* in, out */ ,
169
171
NodeSet &ElimNodes /* in, out */ ) const ;
@@ -198,7 +200,7 @@ struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
198
200
using ChildIteratorType = typename Traits::ChildIteratorType;
199
201
using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;
200
202
201
- DOTGraphTraits (bool isSimple = false ) : DefaultDOTGraphTraits(isSimple ) {}
203
+ DOTGraphTraits (bool IsSimple = false ) : DefaultDOTGraphTraits(IsSimple ) {}
202
204
203
205
std::string getNodeLabel (NodeRef Node, GraphType *) {
204
206
if (Node->getValue () == MachineGadgetGraph::ArgNodeSentinel)
@@ -243,7 +245,7 @@ void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
243
245
AU.setPreservesCFG ();
244
246
}
245
247
246
- static void WriteGadgetGraph (raw_ostream &OS, MachineFunction &MF,
248
+ static void writeGadgetGraph (raw_ostream &OS, MachineFunction &MF,
247
249
MachineGadgetGraph *G) {
248
250
WriteGraph (OS, G, /* ShortNames*/ false ,
249
251
" Speculative gadgets for \" " + MF.getName () + " \" function" );
@@ -279,7 +281,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
279
281
return false ; // didn't find any gadgets
280
282
281
283
if (EmitDotVerify) {
282
- WriteGadgetGraph (outs (), MF, Graph.get ());
284
+ writeGadgetGraph (outs (), MF, Graph.get ());
283
285
return false ;
284
286
}
285
287
@@ -292,7 +294,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
292
294
raw_fd_ostream FileOut (FileName, FileError);
293
295
if (FileError)
294
296
errs () << FileError.message ();
295
- WriteGadgetGraph (FileOut, MF, Graph.get ());
297
+ writeGadgetGraph (FileOut, MF, Graph.get ());
296
298
FileOut.close ();
297
299
LLVM_DEBUG (dbgs () << " Emitting gadget graph... Done\n " );
298
300
if (EmitDotOnly)
@@ -313,7 +315,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
313
315
}
314
316
FencesInserted = hardenLoadsWithPlugin (MF, std::move (Graph));
315
317
} else { // Use the default greedy heuristic
316
- FencesInserted = hardenLoadsWithGreedyHeuristic (MF, std::move (Graph));
318
+ FencesInserted = hardenLoadsWithHeuristic (MF, std::move (Graph));
317
319
}
318
320
319
321
if (FencesInserted > 0 )
@@ -540,47 +542,46 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
540
542
541
543
// Returns the number of remaining gadget edges that could not be eliminated
542
544
int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes (
543
- MachineGadgetGraph &G, MachineGadgetGraph:: EdgeSet &ElimEdges /* in, out */ ,
544
- MachineGadgetGraph:: NodeSet &ElimNodes /* in, out */ ) const {
545
+ MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */ ,
546
+ NodeSet &ElimNodes /* in, out */ ) const {
545
547
if (G.NumFences > 0 ) {
546
548
// Eliminate fences and CFG edges that ingress and egress the fence, as
547
549
// they are trivially mitigated.
548
- for (const auto &E : G.edges ()) {
549
- const MachineGadgetGraph:: Node *Dest = E.getDest ();
550
+ for (const Edge &E : G.edges ()) {
551
+ const Node *Dest = E.getDest ();
550
552
if (isFence (Dest->getValue ())) {
551
553
ElimNodes.insert (*Dest);
552
554
ElimEdges.insert (E);
553
- for (const auto &DE : Dest->edges ())
555
+ for (const Edge &DE : Dest->edges ())
554
556
ElimEdges.insert (DE);
555
557
}
556
558
}
557
559
}
558
560
559
561
// Find and eliminate gadget edges that have been mitigated.
560
562
int MitigatedGadgets = 0 , RemainingGadgets = 0 ;
561
- MachineGadgetGraph:: NodeSet ReachableNodes{G};
562
- for (const auto &RootN : G.nodes ()) {
563
+ NodeSet ReachableNodes{G};
564
+ for (const Node &RootN : G.nodes ()) {
563
565
if (llvm::none_of (RootN.edges (), MachineGadgetGraph::isGadgetEdge))
564
566
continue ; // skip this node if it isn't a gadget source
565
567
566
568
// Find all of the nodes that are CFG-reachable from RootN using DFS
567
569
ReachableNodes.clear ();
568
- std::function<void (const MachineGadgetGraph::Node *, bool )>
569
- FindReachableNodes =
570
- [&](const MachineGadgetGraph::Node *N, bool FirstNode) {
571
- if (!FirstNode)
572
- ReachableNodes.insert (*N);
573
- for (const auto &E : N->edges ()) {
574
- const MachineGadgetGraph::Node *Dest = E.getDest ();
575
- if (MachineGadgetGraph::isCFGEdge (E) &&
576
- !ElimEdges.contains (E) && !ReachableNodes.contains (*Dest))
577
- FindReachableNodes (Dest, false );
578
- }
579
- };
570
+ std::function<void (const Node *, bool )> FindReachableNodes =
571
+ [&](const Node *N, bool FirstNode) {
572
+ if (!FirstNode)
573
+ ReachableNodes.insert (*N);
574
+ for (const Edge &E : N->edges ()) {
575
+ const Node *Dest = E.getDest ();
576
+ if (MachineGadgetGraph::isCFGEdge (E) && !ElimEdges.contains (E) &&
577
+ !ReachableNodes.contains (*Dest))
578
+ FindReachableNodes (Dest, false );
579
+ }
580
+ };
580
581
FindReachableNodes (&RootN, true );
581
582
582
583
// Any gadget whose sink is unreachable has been mitigated
583
- for (const auto &E : RootN.edges ()) {
584
+ for (const Edge &E : RootN.edges ()) {
584
585
if (MachineGadgetGraph::isGadgetEdge (E)) {
585
586
if (ReachableNodes.contains (*E.getDest ())) {
586
587
// This gadget's sink is reachable
@@ -598,8 +599,8 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
598
599
std::unique_ptr<MachineGadgetGraph>
599
600
X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges (
600
601
std::unique_ptr<MachineGadgetGraph> Graph) const {
601
- MachineGadgetGraph:: NodeSet ElimNodes{*Graph};
602
- MachineGadgetGraph:: EdgeSet ElimEdges{*Graph};
602
+ NodeSet ElimNodes{*Graph};
603
+ EdgeSet ElimEdges{*Graph};
603
604
int RemainingGadgets =
604
605
elimMitigatedEdgesAndNodes (*Graph, ElimEdges, ElimNodes);
605
606
if (ElimEdges.empty () && ElimNodes.empty ()) {
@@ -630,11 +631,11 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
630
631
auto Edges = std::make_unique<unsigned int []>(Graph->edges_size ());
631
632
auto EdgeCuts = std::make_unique<int []>(Graph->edges_size ());
632
633
auto EdgeValues = std::make_unique<int []>(Graph->edges_size ());
633
- for (const auto &N : Graph->nodes ()) {
634
+ for (const Node &N : Graph->nodes ()) {
634
635
Nodes[Graph->getNodeIndex (N)] = Graph->getEdgeIndex (*N.edges_begin ());
635
636
}
636
637
Nodes[Graph->nodes_size ()] = Graph->edges_size (); // terminator node
637
- for (const auto &E : Graph->edges ()) {
638
+ for (const Edge &E : Graph->edges ()) {
638
639
Edges[Graph->getEdgeIndex (E)] = Graph->getNodeIndex (*E.getDest ());
639
640
EdgeValues[Graph->getEdgeIndex (E)] = E.getValue ();
640
641
}
@@ -651,74 +652,67 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
651
652
LLVM_DEBUG (dbgs () << " Inserting LFENCEs... Done\n " );
652
653
LLVM_DEBUG (dbgs () << " Inserted " << FencesInserted << " fences\n " );
653
654
654
- Graph = GraphBuilder::trim (*Graph, MachineGadgetGraph::NodeSet{*Graph},
655
- CutEdges);
655
+ Graph = GraphBuilder::trim (*Graph, NodeSet{*Graph}, CutEdges);
656
656
} while (true );
657
657
658
658
return FencesInserted;
659
659
}
660
660
661
- int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic (
661
+ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic (
662
662
MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
663
- LLVM_DEBUG (dbgs () << " Eliminating mitigated paths...\n " );
664
- Graph = trimMitigatedEdges (std::move (Graph));
665
- LLVM_DEBUG (dbgs () << " Eliminating mitigated paths... Done\n " );
663
+ // If `MF` does not have any fences, then no gadgets would have been
664
+ // mitigated at this point.
665
+ if (Graph->NumFences > 0 ) {
666
+ LLVM_DEBUG (dbgs () << " Eliminating mitigated paths...\n " );
667
+ Graph = trimMitigatedEdges (std::move (Graph));
668
+ LLVM_DEBUG (dbgs () << " Eliminating mitigated paths... Done\n " );
669
+ }
670
+
666
671
if (Graph->NumGadgets == 0 )
667
672
return 0 ;
668
673
669
674
LLVM_DEBUG (dbgs () << " Cutting edges...\n " );
670
- MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph};
671
- MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph};
672
- auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) {
673
- return !ElimEdges.contains (E) && !CutEdges.contains (E) &&
674
- MachineGadgetGraph::isCFGEdge (E);
675
- };
676
- auto IsGadgetEdge = [&ElimEdges,
677
- &CutEdges](const MachineGadgetGraph::Edge &E) {
678
- return !ElimEdges.contains (E) && !CutEdges.contains (E) &&
679
- MachineGadgetGraph::isGadgetEdge (E);
680
- };
681
-
682
- // FIXME: this is O(E^2), we could probably do better.
683
- do {
684
- // Find the cheapest CFG edge that will eliminate a gadget (by being
685
- // egress from a SOURCE node or ingress to a SINK node), and cut it.
686
- const MachineGadgetGraph::Edge *CheapestSoFar = nullptr ;
687
-
688
- // First, collect all gadget source and sink nodes.
689
- MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph};
690
- for (const auto &N : Graph->nodes ()) {
691
- if (ElimNodes.contains (N))
675
+ EdgeSet CutEdges{*Graph};
676
+
677
+ // Begin by collecting all ingress CFG edges for each node
678
+ DenseMap<const Node *, SmallVector<const Edge *, 2 >> IngressEdgeMap;
679
+ for (const Edge &E : Graph->edges ())
680
+ if (MachineGadgetGraph::isCFGEdge (E))
681
+ IngressEdgeMap[E.getDest ()].push_back (&E);
682
+
683
+ // For each gadget edge, make cuts that guarantee the gadget will be
684
+ // mitigated. A computationally efficient way to achieve this is to either:
685
+ // (a) cut all egress CFG edges from the gadget source, or
686
+ // (b) cut all ingress CFG edges to the gadget sink.
687
+ //
688
+ // Moreover, the algorithm tries not to make a cut into a loop by preferring
689
+ // to make a (b)-type cut if the gadget source resides at a greater loop depth
690
+ // than the gadget sink, or an (a)-type cut otherwise.
691
+ for (const Node &N : Graph->nodes ()) {
692
+ for (const Edge &E : N.edges ()) {
693
+ if (!MachineGadgetGraph::isGadgetEdge (E))
692
694
continue ;
693
- for (const auto &E : N.edges ()) {
694
- if (IsGadgetEdge (E)) {
695
- GadgetSources.insert (N);
696
- GadgetSinks.insert (*E.getDest ());
697
- }
698
- }
699
- }
700
695
701
- // Next, look for the cheapest CFG edge which, when cut, is guaranteed to
702
- // mitigate at least one gadget by either:
703
- // (a) being egress from a gadget source, or
704
- // (b) being ingress to a gadget sink.
705
- for (const auto &N : Graph->nodes ()) {
706
- if (ElimNodes.contains (N))
707
- continue ;
708
- for (const auto &E : N.edges ()) {
709
- if (IsCFGEdge (E)) {
710
- if (GadgetSources.contains (N) || GadgetSinks.contains (*E.getDest ())) {
711
- if (!CheapestSoFar || E.getValue () < CheapestSoFar->getValue ())
712
- CheapestSoFar = &E;
713
- }
714
- }
715
- }
696
+ SmallVector<const Edge *, 2 > EgressEdges;
697
+ SmallVector<const Edge *, 2 > &IngressEdges = IngressEdgeMap[E.getDest ()];
698
+ for (const Edge &EgressEdge : N.edges ())
699
+ if (MachineGadgetGraph::isCFGEdge (EgressEdge))
700
+ EgressEdges.push_back (&EgressEdge);
701
+
702
+ int EgressCutCost = 0 , IngressCutCost = 0 ;
703
+ for (const Edge *EgressEdge : EgressEdges)
704
+ if (!CutEdges.contains (*EgressEdge))
705
+ EgressCutCost += EgressEdge->getValue ();
706
+ for (const Edge *IngressEdge : IngressEdges)
707
+ if (!CutEdges.contains (*IngressEdge))
708
+ IngressCutCost += IngressEdge->getValue ();
709
+
710
+ auto &EdgesToCut =
711
+ IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges;
712
+ for (const Edge *E : EdgesToCut)
713
+ CutEdges.insert (*E);
716
714
}
717
-
718
- assert (CheapestSoFar && " Failed to cut an edge" );
719
- CutEdges.insert (*CheapestSoFar);
720
- ElimEdges.insert (*CheapestSoFar);
721
- } while (elimMitigatedEdgesAndNodes (*Graph, ElimEdges, ElimNodes));
715
+ }
722
716
LLVM_DEBUG (dbgs () << " Cutting edges... Done\n " );
723
717
LLVM_DEBUG (dbgs () << " Cut " << CutEdges.count () << " edges\n " );
724
718
@@ -734,8 +728,8 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
734
728
MachineFunction &MF, MachineGadgetGraph &G,
735
729
EdgeSet &CutEdges /* in, out */ ) const {
736
730
int FencesInserted = 0 ;
737
- for (const auto &N : G.nodes ()) {
738
- for (const auto &E : N.edges ()) {
731
+ for (const Node &N : G.nodes ()) {
732
+ for (const Edge &E : N.edges ()) {
739
733
if (CutEdges.contains (E)) {
740
734
MachineInstr *MI = N.getValue (), *Prev;
741
735
MachineBasicBlock *MBB; // Insert an LFENCE in this MBB
@@ -751,7 +745,7 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
751
745
Prev = MI->getPrevNode ();
752
746
// Remove all egress CFG edges from this branch because the inserted
753
747
// LFENCE prevents gadgets from crossing the branch.
754
- for (const auto &E : N.edges ()) {
748
+ for (const Edge &E : N.edges ()) {
755
749
if (MachineGadgetGraph::isCFGEdge (E))
756
750
CutEdges.insert (E);
757
751
}
0 commit comments