Skip to content

Commit ccc3127

Browse files
authored
[NVPTX] support switch statement with brx.idx (reland) (#102550)
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx](https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)). Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`. Note: this fixes the previous issue in #102400 by adding the `isBarrier` attribute to `BRX_END`.
1 parent b6cbd01 commit ccc3127

File tree

6 files changed

+272
-8
lines changed

6 files changed

+272
-8
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
38433843
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
38443844
virtual unsigned getJumpTableEncoding() const;
38453845

3846+
/// Return the value type of the register that carries the jump-table index
/// when a switch is lowered through a jump table. SelectionDAGBuilder uses
/// this type both for the index register and for the jump-table node it
/// builds. The default is the pointer type for \p DL; targets may override
/// this hook to request a different type.
virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
3847+
return getPointerTy(DL);
3848+
}
3849+
38463850
virtual const MCExpr *
38473851
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
38483852
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
29772977
// Emit the code for the jump table
29782978
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
29792979
assert(JT.Reg != -1U && "Should lower JT Header first!");
2980-
EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
2980+
EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
29812981
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
29822982
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
29832983
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
30053005
// This value may be smaller or larger than the target's pointer type, and
30063006
// therefore require extension or truncating.
30073007
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
3008-
SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
3008+
SwitchOp =
3009+
DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
30093010

30103011
unsigned JumpTableReg =
3011-
FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
3012-
SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
3013-
JumpTableReg, SwitchOp);
3012+
FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
3013+
SDValue CopyTo =
3014+
DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
30143015
JT.Reg = JumpTableReg;
30153016

30163017
if (!JTH.FallthroughUnreachable) {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/CodeGen/Analysis.h"
2626
#include "llvm/CodeGen/ISDOpcodes.h"
2727
#include "llvm/CodeGen/MachineFunction.h"
28+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2829
#include "llvm/CodeGen/MachineMemOperand.h"
2930
#include "llvm/CodeGen/SelectionDAG.h"
3031
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
582583
setOperationAction(ISD::ROTR, MVT::i8, Expand);
583584
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
584585

585-
// Indirect branch is not supported.
586-
// This also disables Jump Table creation.
587-
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
586+
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
588587
setOperationAction(ISD::BRIND, MVT::Other, Expand);
589588

590589
setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
945944
MAKE_CASE(NVPTXISD::Dummy)
946945
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
947946
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
947+
MAKE_CASE(NVPTXISD::BrxEnd)
948+
MAKE_CASE(NVPTXISD::BrxItem)
949+
MAKE_CASE(NVPTXISD::BrxStart)
948950
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
949951
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
950952
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27852787
return LowerFP_ROUND(Op, DAG);
27862788
case ISD::FP_EXTEND:
27872789
return LowerFP_EXTEND(Op, DAG);
2790+
case ISD::BR_JT:
2791+
return LowerBR_JT(Op, DAG);
27882792
case ISD::VAARG:
27892793
return LowerVAARG(Op, DAG);
27902794
case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28102814
}
28112815
}
28122816

2817+
// Custom lowering for ISD::BR_JT.
//
// Replaces the jump-table branch with a sequence of NVPTX-specific nodes:
// one BrxStart (carrying the jump-table id), one BrxItem for every target
// block except the last, and a final BrxEnd that carries the last target
// block, the index value and the jump-table id. Each node takes the chain and
// glue of the node before it. Returns the BrxEnd node.
SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2818+
SDLoc DL(Op);
2819+
// BR_JT operands as read here: (0) chain, (1) jump table, (2) index.
SDValue Chain = Op.getOperand(0);
2820+
const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2821+
SDValue Index = Op.getOperand(2);
2822+
2823+
// Look up the list of target blocks recorded for this jump table.
unsigned JId = JT->getIndex();
2824+
MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2825+
ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2826+
2827+
// The jump-table id is passed as an i32 constant to both BrxStart and
// BrxEnd.
SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2828+
2829+
// Generate BrxStart node. Every Brx* node built here yields a chain
// (value 0) and a glue (value 1).
2830+
SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2831+
Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2832+
2833+
// Generate BrxItem nodes for all targets but the last one. The table must
// be non-empty because drop_back() and back() are used below.
2834+
assert(!MBBs.empty());
2835+
for (MachineBasicBlock *MBB : MBBs.drop_back())
2836+
Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2837+
DAG.getBasicBlock(MBB), Chain.getValue(1));
2838+
2839+
// Generate the BrxEnd node: it takes the last target block together with
// the index value and the jump-table id.
2840+
SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2841+
IdV, Chain.getValue(1)};
2842+
SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2843+
2844+
return BrxEnd;
2845+
}
2846+
2847+
// Report the jump-table encoding as EK_Inline. This will prevent AsmPrinter
// from trying to print the jump tables itself; the branch targets are instead
// printed by the BRX_START/BRX_ITEM/BRX_END instructions that LowerBR_JT
// selects to.
2848+
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2849+
return MachineJumpTableInfo::EK_Inline;
2850+
}
2851+
28132852
// This function is almost a copy of SelectionDAG::expandVAArg().
28142853
// The only diff is that this one produces loads from local address space.
28152854
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
6262
BFI,
6363
PRMT,
6464
DYNAMIC_STACKALLOC,
65+
BrxStart,
66+
BrxItem,
67+
BrxEnd,
6568
Dummy,
6669

6770
LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
580583
return true;
581584
}
582585

586+
// The default jump-table register type is the same as the pointer type, but
// brx.idx only accepts an i32 index, so always use i32 regardless of the
// data layout.
587+
MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
588+
589+
unsigned getJumpTableEncoding() const override;
590+
583591
bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
584592

585593
// The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
637645

638646
SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
639647

648+
SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
649+
640650
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
641651
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
642652

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 :
38803880
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
38813881
Requires<[hasPTX<73>, hasSM<52>]>;
38823882

3883+
3884+
//
3885+
// BRX
3886+
//
3887+
3888+
// Type profiles for the brx nodes. None of them produces a result value.
//   BrxStart: one integer operand (the jump-table id).
//   BrxItem:  one OtherVT operand (a target basic block).
//   BrxEnd:   an OtherVT operand (the last target basic block) followed by two
//             integer operands (the index value and the jump-table id).
def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
3889+
def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
3890+
def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
3891+
3892+
// brx_start opens the sequence: it has a chain and produces glue only.
def brx_start :
3893+
SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
3894+
[SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
3895+
// brx_item sits in the middle: it both consumes and produces glue.
def brx_item :
3896+
SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
3897+
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
3898+
// brx_end closes the sequence: it consumes glue and produces none.
def brx_end :
3899+
SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
3900+
[SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
3901+
3902+
// The three instructions below together print one `.branchtargets` list
// followed by the `brx.idx` that uses it. All of them are flagged as
// non-duplicable indirect-branch terminators.
let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in {
3903+
3904+
// Prints the list label `$L_brx_<id>:` and the `.branchtargets` directive.
def BRX_START :
3905+
NVPTXInst<(outs), (ins i32imm:$id),
3906+
"$$L_brx_$id: .branchtargets",
3907+
[(brx_start (i32 imm:$id))]>;
3908+
3909+
// Prints one non-final list entry: the target label followed by a comma.
def BRX_ITEM :
3910+
NVPTXInst<(outs), (ins brtarget:$target),
3911+
"\t$target,",
3912+
[(brx_item bb:$target)]>;
3913+
3914+
// Prints the final list entry (terminated by a semicolon) and then the
// `brx.idx` instruction itself, indexing the `$L_brx_<id>` list with $val.
def BRX_END :
3915+
NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
3916+
"\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
3917+
[(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> {
3918+
// Only the instruction that actually emits brx.idx is marked as a barrier.
let isBarrier = 1;
3919+
}
3920+
}
3921+
3922+
38833923
include "NVPTXIntrinsics.td"
38843924

38853925
//-----------------------------------

llvm/test/CodeGen/NVPTX/jump-table.ll

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
@out = addrspace(1) global i32 0, align 4
8+
9+
define void @foo(i32 %i) {
10+
; CHECK-LABEL: foo(
11+
; CHECK: {
12+
; CHECK-NEXT: .reg .pred %p<2>;
13+
; CHECK-NEXT: .reg .b32 %r<7>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0: // %entry
16+
; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
17+
; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
18+
; CHECK-NEXT: @%p1 bra $L__BB0_6;
19+
; CHECK-NEXT: // %bb.1: // %entry
20+
; CHECK-NEXT: $L_brx_0: .branchtargets
21+
; CHECK-NEXT: $L__BB0_2,
22+
; CHECK-NEXT: $L__BB0_3,
23+
; CHECK-NEXT: $L__BB0_4,
24+
; CHECK-NEXT: $L__BB0_5;
25+
; CHECK-NEXT: brx.idx %r2, $L_brx_0;
26+
; CHECK-NEXT: $L__BB0_2: // %case0
27+
; CHECK-NEXT: mov.b32 %r6, 0;
28+
; CHECK-NEXT: st.global.u32 [out], %r6;
29+
; CHECK-NEXT: bra.uni $L__BB0_6;
30+
; CHECK-NEXT: $L__BB0_4: // %case2
31+
; CHECK-NEXT: mov.b32 %r4, 2;
32+
; CHECK-NEXT: st.global.u32 [out], %r4;
33+
; CHECK-NEXT: bra.uni $L__BB0_6;
34+
; CHECK-NEXT: $L__BB0_5: // %case3
35+
; CHECK-NEXT: mov.b32 %r3, 3;
36+
; CHECK-NEXT: st.global.u32 [out], %r3;
37+
; CHECK-NEXT: bra.uni $L__BB0_6;
38+
; CHECK-NEXT: $L__BB0_3: // %case1
39+
; CHECK-NEXT: mov.b32 %r5, 1;
40+
; CHECK-NEXT: st.global.u32 [out], %r5;
41+
; CHECK-NEXT: $L__BB0_6: // %end
42+
; CHECK-NEXT: ret;
43+
entry:
44+
switch i32 %i, label %end [
45+
i32 0, label %case0
46+
i32 1, label %case1
47+
i32 2, label %case2
48+
i32 3, label %case3
49+
]
50+
51+
case0:
52+
store i32 0, ptr addrspace(1) @out, align 4
53+
br label %end
54+
55+
case1:
56+
store i32 1, ptr addrspace(1) @out, align 4
57+
br label %end
58+
59+
case2:
60+
store i32 2, ptr addrspace(1) @out, align 4
61+
br label %end
62+
63+
case3:
64+
store i32 3, ptr addrspace(1) @out, align 4
65+
br label %end
66+
67+
end:
68+
ret void
69+
}
70+
71+
72+
define i32 @test2(i32 %tmp158) {
73+
; CHECK-LABEL: test2(
74+
; CHECK: {
75+
; CHECK-NEXT: .reg .pred %p<6>;
76+
; CHECK-NEXT: .reg .b32 %r<10>;
77+
; CHECK-EMPTY:
78+
; CHECK-NEXT: // %bb.0: // %entry
79+
; CHECK-NEXT: ld.param.u32 %r1, [test2_param_0];
80+
; CHECK-NEXT: setp.gt.s32 %p1, %r1, 119;
81+
; CHECK-NEXT: @%p1 bra $L__BB1_4;
82+
; CHECK-NEXT: // %bb.1: // %entry
83+
; CHECK-NEXT: setp.lt.u32 %p4, %r1, 6;
84+
; CHECK-NEXT: @%p4 bra $L__BB1_3;
85+
; CHECK-NEXT: // %bb.2: // %entry
86+
; CHECK-NEXT: setp.lt.s32 %p5, %r1, -2147483645;
87+
; CHECK-NEXT: @%p5 bra $L__BB1_3;
88+
; CHECK-NEXT: bra.uni $L__BB1_6;
89+
; CHECK-NEXT: $L__BB1_4: // %entry
90+
; CHECK-NEXT: add.s32 %r2, %r1, -120;
91+
; CHECK-NEXT: setp.gt.u32 %p2, %r2, 5;
92+
; CHECK-NEXT: @%p2 bra $L__BB1_5;
93+
; CHECK-NEXT: // %bb.12: // %entry
94+
; CHECK-NEXT: $L_brx_0: .branchtargets
95+
; CHECK-NEXT: $L__BB1_3,
96+
; CHECK-NEXT: $L__BB1_7,
97+
; CHECK-NEXT: $L__BB1_8,
98+
; CHECK-NEXT: $L__BB1_9,
99+
; CHECK-NEXT: $L__BB1_10,
100+
; CHECK-NEXT: $L__BB1_11;
101+
; CHECK-NEXT: brx.idx %r2, $L_brx_0;
102+
; CHECK-NEXT: $L__BB1_7: // %bb339
103+
; CHECK-NEXT: mov.b32 %r7, 12;
104+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r7;
105+
; CHECK-NEXT: ret;
106+
; CHECK-NEXT: $L__BB1_5: // %entry
107+
; CHECK-NEXT: setp.eq.s32 %p3, %r1, 1024;
108+
; CHECK-NEXT: @%p3 bra $L__BB1_3;
109+
; CHECK-NEXT: bra.uni $L__BB1_6;
110+
; CHECK-NEXT: $L__BB1_3: // %bb338
111+
; CHECK-NEXT: mov.b32 %r8, 11;
112+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r8;
113+
; CHECK-NEXT: ret;
114+
; CHECK-NEXT: $L__BB1_10: // %bb342
115+
; CHECK-NEXT: mov.b32 %r4, 15;
116+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
117+
; CHECK-NEXT: ret;
118+
; CHECK-NEXT: $L__BB1_6: // %bb336
119+
; CHECK-NEXT: mov.b32 %r9, 10;
120+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r9;
121+
; CHECK-NEXT: ret;
122+
; CHECK-NEXT: $L__BB1_8: // %bb340
123+
; CHECK-NEXT: mov.b32 %r6, 13;
124+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r6;
125+
; CHECK-NEXT: ret;
126+
; CHECK-NEXT: $L__BB1_9: // %bb341
127+
; CHECK-NEXT: mov.b32 %r5, 14;
128+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
129+
; CHECK-NEXT: ret;
130+
; CHECK-NEXT: $L__BB1_11: // %bb343
131+
; CHECK-NEXT: mov.b32 %r3, 18;
132+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
133+
; CHECK-NEXT: ret;
134+
entry:
135+
switch i32 %tmp158, label %bb336 [
136+
i32 -2147483648, label %bb338
137+
i32 -2147483647, label %bb338
138+
i32 -2147483646, label %bb338
139+
i32 120, label %bb338
140+
i32 121, label %bb339
141+
i32 122, label %bb340
142+
i32 123, label %bb341
143+
i32 124, label %bb342
144+
i32 125, label %bb343
145+
i32 126, label %bb336
146+
i32 1024, label %bb338
147+
i32 0, label %bb338
148+
i32 1, label %bb338
149+
i32 2, label %bb338
150+
i32 3, label %bb338
151+
i32 4, label %bb338
152+
i32 5, label %bb338
153+
]
154+
155+
bb336:
156+
ret i32 10
157+
bb338:
158+
ret i32 11
159+
bb339:
160+
ret i32 12
161+
bb340:
162+
ret i32 13
163+
bb341:
164+
ret i32 14
165+
bb342:
166+
ret i32 15
167+
bb343:
168+
ret i32 18
169+
170+
}

0 commit comments

Comments
 (0)