Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] support switch statement with brx.idx (reland) #102550

Merged
merged 3 commits into from
Aug 9, 2024

Conversation

AlexMaclean
Copy link
Member

Add custom lowering for BR_JT DAG nodes to the brx.idx PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx).
Depending on the heuristics in DAG selection, switch statements may now be lowered using brx.idx.

Note: this fixes the previous issue in #102400 by adding the isBarrier attribute to BRX_END

Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX
instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)).
Depending on the heuristics in DAG selection, `switch` statements may
now be lowered using `brx.idx`
@AlexMaclean AlexMaclean requested a review from Artem-B August 8, 2024 23:20
@AlexMaclean AlexMaclean self-assigned this Aug 8, 2024
@llvmbot llvmbot added backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well labels Aug 8, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 8, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Alex MacLean (AlexMaclean)

Changes

Add custom lowering for BR_JT DAG nodes to the brx.idx PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx).
Depending on the heuristics in DAG selection, switch statements may now be lowered using brx.idx.

Note: this fixes the previous issue in #102400 by adding the isBarrier attribute to BRX_END


Full diff: https://github.com/llvm/llvm-project/pull/102550.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+42-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+40)
  • (added) llvm/test/CodeGen/NVPTX/jump-table.ll (+69)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9ccdbab008aec8..5b2214fa66c40b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
   /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
   virtual unsigned getJumpTableEncoding() const;
 
+  virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
+    return getPointerTy(DL);
+  }
+
   virtual const MCExpr *
   LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
                             const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..37ba62911ec70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
   // Emit the code for the jump table
   assert(JT.SL && "Should set SDLoc for SelectionDAG!");
   assert(JT.Reg != -1U && "Should lower JT Header first!");
-  EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+  EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
   SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
   SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
   SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
   // This value may be smaller or larger than the target's pointer type, and
   // therefore require extension or truncating.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
+  SwitchOp =
+      DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
 
   unsigned JumpTableReg =
-      FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
-  SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
-                                    JumpTableReg, SwitchOp);
+      FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
+  SDValue CopyTo =
+      DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
   JT.Reg = JumpTableReg;
 
   if (!JTH.FallthroughUnreachable) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 516fc7339a4bf3..bf647c88f00e28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -25,6 +25,7 @@
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::ROTR, MVT::i8, Expand);
   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
 
-  // Indirect branch is not supported.
-  // This also disables Jump Table creation.
-  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
+  setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::BRIND, MVT::Other, Expand);
 
   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+    MAKE_CASE(NVPTXISD::BrxEnd)
+    MAKE_CASE(NVPTXISD::BrxItem)
+    MAKE_CASE(NVPTXISD::BrxStart)
     MAKE_CASE(NVPTXISD::Tex1DFloatS32)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerFP_ROUND(Op, DAG);
   case ISD::FP_EXTEND:
     return LowerFP_EXTEND(Op, DAG);
+  case ISD::BR_JT:
+    return LowerBR_JT(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   }
 }
 
+SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
+  SDValue Index = Op.getOperand(2);
+
+  unsigned JId = JT->getIndex();
+  MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
+  ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
+
+  SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
+
+  // Generate BrxStart node
+  SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+  Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
+
+  // Generate BrxItem nodes
+  assert(!MBBs.empty());
+  for (MachineBasicBlock *MBB : MBBs.drop_back())
+    Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
+                        DAG.getBasicBlock(MBB), Chain.getValue(1));
+
+  // Generate BrxEnd nodes
+  SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
+                      IdV, Chain.getValue(1)};
+  SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
+
+  return BrxEnd;
+}
+
+// This will prevent AsmPrinter from trying to print the jump tables itself.
+unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
+  return MachineJumpTableInfo::EK_Inline;
+}
+
 // This function is almost a copy of SelectionDAG::expandVAArg().
 // The only diff is that this one produces loads from local address space.
 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63262961b363ed..32e6b044b0de1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
   BFI,
   PRMT,
   DYNAMIC_STACKALLOC,
+  BrxStart,
+  BrxItem,
+  BrxEnd,
   Dummy,
 
   LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
     return true;
   }
 
+  // The default is the same as pointer type, but brx.idx only accepts i32
+  MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
+
+  unsigned getJumpTableEncoding() const override;
+
   bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
 
   // The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6a096fa5acea7c..d75dc8781f7802 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 :
             [(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
             Requires<[hasPTX<73>, hasSM<52>]>;
 
+
+//
+// BRX
+//
+
+def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
+def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
+
+def brx_start :
+  SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def brx_item :
+  SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def brx_end :
+  SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
+         [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
+
+let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in {
+
+  def BRX_START :
+    NVPTXInst<(outs), (ins i32imm:$id),
+              "$$L_brx_$id: .branchtargets",
+              [(brx_start (i32 imm:$id))]>;
+
+  def BRX_ITEM :
+    NVPTXInst<(outs), (ins brtarget:$target),
+              "\t$target,",
+              [(brx_item bb:$target)]>;
+
+  def BRX_END :
+    NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
+              "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
+              [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> {
+      let isBarrier = 1;
+    }
+}
+
+
 include "NVPTXIntrinsics.td"
 
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
new file mode 100644
index 00000000000000..867e171a5840ae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@out = addrspace(1) global i32 0, align 4
+
+define void @foo(i32 %i) {
+; CHECK-LABEL: foo(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.u32 %r2, [foo_param_0];
+; CHECK-NEXT:    setp.gt.u32 %p1, %r2, 3;
+; CHECK-NEXT:    @%p1 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    $L_brx_0: .branchtargets
+; CHECK-NEXT:     $L__BB0_2,
+; CHECK-NEXT:     $L__BB0_3,
+; CHECK-NEXT:     $L__BB0_4,
+; CHECK-NEXT:     $L__BB0_5;
+; CHECK-NEXT:    brx.idx %r2, $L_brx_0;
+; CHECK-NEXT:  $L__BB0_2: // %case0
+; CHECK-NEXT:    mov.b32 %r6, 0;
+; CHECK-NEXT:    st.global.u32 [out], %r6;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_4: // %case2
+; CHECK-NEXT:    mov.b32 %r4, 2;
+; CHECK-NEXT:    st.global.u32 [out], %r4;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_5: // %case3
+; CHECK-NEXT:    mov.b32 %r3, 3;
+; CHECK-NEXT:    st.global.u32 [out], %r3;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_3: // %case1
+; CHECK-NEXT:    mov.b32 %r5, 1;
+; CHECK-NEXT:    st.global.u32 [out], %r5;
+; CHECK-NEXT:  $L__BB0_6: // %end
+; CHECK-NEXT:    ret;
+entry:
+  switch i32 %i, label %end [
+    i32 0, label %case0
+    i32 1, label %case1
+    i32 2, label %case2
+    i32 3, label %case3
+  ]
+
+case0:
+  store i32 0, ptr addrspace(1) @out, align 4
+  br label %end
+
+case1:
+  store i32 1, ptr addrspace(1) @out, align 4
+  br label %end
+
+case2:
+  store i32 2, ptr addrspace(1) @out, align 4
+  br label %end
+
+case3:
+  store i32 3, ptr addrspace(1) @out, align 4
+  br label %end
+
+end:
+  ret void
+}

@llvmbot
Copy link
Member

llvmbot commented Aug 8, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Add custom lowering for BR_JT DAG nodes to the brx.idx PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx).
Depending on the heuristics in DAG selection, switch statements may now be lowered using brx.idx.

Note: this fixes the previous issue in #102400 by adding the isBarrier attribute to BRX_END


Full diff: https://github.com/llvm/llvm-project/pull/102550.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+42-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+40)
  • (added) llvm/test/CodeGen/NVPTX/jump-table.ll (+69)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9ccdbab008aec8..5b2214fa66c40b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
   /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
   virtual unsigned getJumpTableEncoding() const;
 
+  virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
+    return getPointerTy(DL);
+  }
+
   virtual const MCExpr *
   LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
                             const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..37ba62911ec70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
   // Emit the code for the jump table
   assert(JT.SL && "Should set SDLoc for SelectionDAG!");
   assert(JT.Reg != -1U && "Should lower JT Header first!");
-  EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+  EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
   SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
   SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
   SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
   // This value may be smaller or larger than the target's pointer type, and
   // therefore require extension or truncating.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
+  SwitchOp =
+      DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
 
   unsigned JumpTableReg =
-      FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
-  SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
-                                    JumpTableReg, SwitchOp);
+      FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
+  SDValue CopyTo =
+      DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
   JT.Reg = JumpTableReg;
 
   if (!JTH.FallthroughUnreachable) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 516fc7339a4bf3..bf647c88f00e28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -25,6 +25,7 @@
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::ROTR, MVT::i8, Expand);
   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
 
-  // Indirect branch is not supported.
-  // This also disables Jump Table creation.
-  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
+  setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::BRIND, MVT::Other, Expand);
 
   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+    MAKE_CASE(NVPTXISD::BrxEnd)
+    MAKE_CASE(NVPTXISD::BrxItem)
+    MAKE_CASE(NVPTXISD::BrxStart)
     MAKE_CASE(NVPTXISD::Tex1DFloatS32)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerFP_ROUND(Op, DAG);
   case ISD::FP_EXTEND:
     return LowerFP_EXTEND(Op, DAG);
+  case ISD::BR_JT:
+    return LowerBR_JT(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   }
 }
 
+SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
+  SDValue Index = Op.getOperand(2);
+
+  unsigned JId = JT->getIndex();
+  MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
+  ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
+
+  SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
+
+  // Generate BrxStart node
+  SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+  Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
+
+  // Generate BrxItem nodes
+  assert(!MBBs.empty());
+  for (MachineBasicBlock *MBB : MBBs.drop_back())
+    Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
+                        DAG.getBasicBlock(MBB), Chain.getValue(1));
+
+  // Generate BrxEnd nodes
+  SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
+                      IdV, Chain.getValue(1)};
+  SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
+
+  return BrxEnd;
+}
+
+// This will prevent AsmPrinter from trying to print the jump tables itself.
+unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
+  return MachineJumpTableInfo::EK_Inline;
+}
+
 // This function is almost a copy of SelectionDAG::expandVAArg().
 // The only diff is that this one produces loads from local address space.
 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63262961b363ed..32e6b044b0de1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
   BFI,
   PRMT,
   DYNAMIC_STACKALLOC,
+  BrxStart,
+  BrxItem,
+  BrxEnd,
   Dummy,
 
   LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
     return true;
   }
 
+  // The default is the same as pointer type, but brx.idx only accepts i32
+  MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
+
+  unsigned getJumpTableEncoding() const override;
+
   bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
 
   // The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6a096fa5acea7c..d75dc8781f7802 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 :
             [(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
             Requires<[hasPTX<73>, hasSM<52>]>;
 
+
+//
+// BRX
+//
+
+def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
+def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
+
+def brx_start :
+  SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def brx_item :
+  SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def brx_end :
+  SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
+         [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
+
+let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in {
+
+  def BRX_START :
+    NVPTXInst<(outs), (ins i32imm:$id),
+              "$$L_brx_$id: .branchtargets",
+              [(brx_start (i32 imm:$id))]>;
+
+  def BRX_ITEM :
+    NVPTXInst<(outs), (ins brtarget:$target),
+              "\t$target,",
+              [(brx_item bb:$target)]>;
+
+  def BRX_END :
+    NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
+              "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
+              [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> {
+      let isBarrier = 1;
+    }
+}
+
+
 include "NVPTXIntrinsics.td"
 
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
new file mode 100644
index 00000000000000..867e171a5840ae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@out = addrspace(1) global i32 0, align 4
+
+define void @foo(i32 %i) {
+; CHECK-LABEL: foo(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.u32 %r2, [foo_param_0];
+; CHECK-NEXT:    setp.gt.u32 %p1, %r2, 3;
+; CHECK-NEXT:    @%p1 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    $L_brx_0: .branchtargets
+; CHECK-NEXT:     $L__BB0_2,
+; CHECK-NEXT:     $L__BB0_3,
+; CHECK-NEXT:     $L__BB0_4,
+; CHECK-NEXT:     $L__BB0_5;
+; CHECK-NEXT:    brx.idx %r2, $L_brx_0;
+; CHECK-NEXT:  $L__BB0_2: // %case0
+; CHECK-NEXT:    mov.b32 %r6, 0;
+; CHECK-NEXT:    st.global.u32 [out], %r6;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_4: // %case2
+; CHECK-NEXT:    mov.b32 %r4, 2;
+; CHECK-NEXT:    st.global.u32 [out], %r4;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_5: // %case3
+; CHECK-NEXT:    mov.b32 %r3, 3;
+; CHECK-NEXT:    st.global.u32 [out], %r3;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_3: // %case1
+; CHECK-NEXT:    mov.b32 %r5, 1;
+; CHECK-NEXT:    st.global.u32 [out], %r5;
+; CHECK-NEXT:  $L__BB0_6: // %end
+; CHECK-NEXT:    ret;
+entry:
+  switch i32 %i, label %end [
+    i32 0, label %case0
+    i32 1, label %case1
+    i32 2, label %case2
+    i32 3, label %case3
+  ]
+
+case0:
+  store i32 0, ptr addrspace(1) @out, align 4
+  br label %end
+
+case1:
+  store i32 1, ptr addrspace(1) @out, align 4
+  br label %end
+
+case2:
+  store i32 2, ptr addrspace(1) @out, align 4
+  br label %end
+
+case3:
+  store i32 3, ptr addrspace(1) @out, align 4
+  br label %end
+
+end:
+  ret void
+}

@AlexMaclean AlexMaclean merged commit ccc3127 into llvm:main Aug 9, 2024
8 checks passed
kutemeikito added a commit to kutemeikito/llvm-project that referenced this pull request Aug 10, 2024
* 'main' of https://github.com/llvm/llvm-project: (700 commits)
  [SandboxIR][NFC] SingleLLVMInstructionImpl class (llvm#102687)
  [ThinLTO]Clean up 'import-assume-unique-local' flag. (llvm#102424)
  [nsan] Make #include more conventional
  [SandboxIR][NFC] Use Tracker.emplaceIfTracking()
  [libc]  Moved range_reduction_double ifdef statement (llvm#102659)
  [libc] Fix CFP long double and add tests (llvm#102660)
  [TargetLowering] Handle vector types in expandFixedPointMul (llvm#102635)
  [compiler-rt][NFC] Replace environment variable with %t (llvm#102197)
  [UnitTests] Convert a test to use opaque pointers (llvm#102668)
  [CodeGen][NFCI] Don't re-implement parts of ASTContext::getIntWidth (llvm#101765)
  [SandboxIR] Clean up tracking code with the help of emplaceIfTracking() (llvm#102406)
  [mlir][bazel] remove extra blanks in mlir-tblgen test
  [NVPTX][NFC] Update tests to use bfloat type (llvm#101493)
  [mlir] Add support for parsing nested PassPipelineOptions (llvm#101118)
  [mlir][bazel] add missing td dependency in mlir-tblgen test
  [flang][cuda] Fix lib dependency
  [libc] Clean up remaining use of *_WIDTH macros in printf (llvm#102679)
  [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (llvm#102662)
  [SandboxIR] Implement the InsertElementInst class (llvm#102404)
  [libc] Fix use of cpp::numeric_limits<...>::digits (llvm#102674)
  [mlir][ODS] Verify type constraints in Types and Attributes (llvm#102326)
  [LTO] enable `ObjCARCContractPass` only on optimized build  (llvm#101114)
  [mlir][ODS] Consistent `cppType` / `cppClassName` usage (llvm#102657)
  [lldb] Move definition of SBSaveCoreOptions dtor out of header (llvm#102539)
  [libc] Use cpp::numeric_limits in preference to C23 <limits.h> macros (llvm#102665)
  [clang] Implement -fptrauth-auth-traps. (llvm#102417)
  [LLVM][rtsan] rtsan transform to preserve CFGAnalyses (llvm#102651)
  Revert "[AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)"
  [RISCV][GISel] Add missing tests for G_CTLZ/CTTZ instruction selection. NFC
  Return available function types for BindingDecls. (llvm#102196)
  [clang] Wire -fptrauth-returns to "ptrauth-returns" fn attribute. (llvm#102416)
  [RISCV] Remove riscv-experimental-rv64-legal-i32. (llvm#102509)
  [RISCV] Move PseudoVSET(I)VLI expansion to use PseudoInstExpansion. (llvm#102496)
  [NVPTX] support switch statement with brx.idx (reland) (llvm#102550)
  [libc][newhdrgen]sorted function names in yaml (llvm#102544)
  [GlobalIsel] Combine G_ADD and G_SUB with constants (llvm#97771)
  Suppress spurious warnings due to R_RISCV_SET_ULEB128
  [scudo] Separated committed and decommitted entries. (llvm#101409)
  [MIPS] Fix missing ANDI optimization (llvm#97689)
  [Clang] Add env var for nvptx-arch/amdgpu-arch timeout (llvm#102521)
  [asan] Switch allocator to dynamic base address (llvm#98511)
  [AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)
  [libc][math][c23] Add fadd{l,f128} C23 math functions (llvm#102531)
  [mlir][bazel] revert bazel rule change for DLTITransformOps
  [msan] Support vst{2,3,4}_lane instructions (llvm#101215)
  Revert "[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)"
  [X86] pr57673.ll - generate MIR test checks
  [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (llvm#102584)
  [mlir][bazel] add bazel rule for DLTITransformOps
  OpenMPOpt: Remove dead include
  [IR] Add method to GlobalVariable to change type of initializer. (llvm#102553)
  [flang][cuda] Force default allocator in device code (llvm#102238)
  [llvm] Construct SmallVector<SDValue> with ArrayRef (NFC) (llvm#102578)
  [MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
  [AMDGPU][AsmParser][NFC] Remove a misleading comment. (llvm#102604)
  [Arm][AArch64][Clang] Respect function's branch protection attributes. (llvm#101978)
  [mlir] Verifier: steal bit to track seen instead of set. (llvm#102626)
  [Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII (llvm#100766)
  [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
  [gn] Give two scripts argparse.RawDescriptionHelpFormatter
  [bazel] Add missing dep for the SPIRVToLLVM target
  [Clang] Simplify specifying passes via -Xoffload-linker (llvm#102483)
  [bazel] Port for d45de80
  [SelectionDAG] Use unaligned store/load to move AVX registers onto stack for `insertelement` (llvm#82130)
  [Clang][OMPX] Add the code generation for multi-dim `num_teams` (llvm#101407)
  [ARM] Regenerate big-endian-vmov.ll. NFC
  [AMDGPU][AsmParser][NFCI] All NamedIntOperands to be of the i32 type. (llvm#102616)
  [libc][math][c23] Add totalorderl function. (llvm#102564)
  [mlir][spirv] Support `memref` in `convert-to-spirv` pass (llvm#102534)
  [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (llvm#101664)
  Fix a unit test input file (llvm#102567)
  [llvm-readobj][COFF] Dump hybrid objects for ARM64X files. (llvm#102245)
  AMDGPU/NewPM: Port SIFixSGPRCopies to new pass manager (llvm#102614)
  [MemoryBuiltins] Simplify getCalledFunction() helper (NFC)
  [AArch64] Add invalid 1 x vscale costs for reductions and reduction-operations. (llvm#102105)
  [MemoryBuiltins] Handle allocator attributes on call-site
  LSV/test/AArch64: add missing lit.local.cfg; fix build (llvm#102607)
  Revert "Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)"
  [RISCV] Add Syntacore SCR5 RV32/64 processors definition (llvm#102285)
  [InstCombine] Remove unnecessary RUN line from test (NFC)
  [flang][OpenMP] Handle multiple ranges in `num_teams` clause (llvm#102535)
  [mlir][vector] Add tests for scalable vectors in one-shot-bufferize.mlir (llvm#102361)
  [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (llvm#102573)
  [clang] Implement CWG2627 Bit-fields and narrowing conversions (llvm#78112)
  [NFC] Use references to avoid copying (llvm#99863)
  Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (llvm#100731)" (llvm#102457)
  [IRBuilder] Generate nuw GEPs for struct member accesses (llvm#99538)
  [bazel] Port for 9b06e25
  [CodeGen][NewPM] Improve start/stop pass error message CodeGenPassBuilder (llvm#102591)
  [AArch64] Implement TRBMPAM_EL1 system register (llvm#102485)
  [InstCombine] Fixing wrong select folding in vectors with undef elements (llvm#102244)
  [AArch64] Sink operands to fmuladd. (llvm#102297)
  LSV: document hang reported in llvm#37865 (llvm#102479)
  Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)
  [RISCV][clang] Remove bfloat base type in non-zvfbfmin vcreate (llvm#102146)
  [RISCV][clang] Add missing `zvfbfmin` to `vget_v` intrinsic (llvm#102149)
  [mlir][vector] Add mask elimination transform (llvm#99314)
  [Clang][Interp] Fix display of syntactically-invalid note for member function calls (llvm#102170)
  [bazel] Port for 3fffa6d
  [DebugInfo][RemoveDIs] Use iterator-inserters in clang (llvm#102006)
  ...

Signed-off-by: Edwiin Kusuma Jaya <kutemeikito0905@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants