Skip to content

Commit cbd3f25

Browse files
authored
[NVPTX] Support inline asm with 128-bit operand in NVPTX backend (#97113)
This change supports the 128-bit operands for inline ptx asm, both input and output.\ \ The major changes are: - Tablegen:\     Define Int128Regs in NVPTXRegisterInfo.td. But this register does not set as general register type in NVPTX backend so that this change will not influence the codegen without inline asm.\     Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and I128toV2I64. The first one moves a register, the second one moves two 64-bit registers into one 128-bit register, and the third one just does the opposite. - NVPTXISelLowering & NVPTXISelDAGToDAG:\     Custom lowering CopyToReg and CopyFromReg with 128-bit operands. CopyToReg deals with the inputs of the inline asm and the CopyFromReg deals with the outputs.\     CopyToReg is custom lowered into a V2I64toI128, which takes in the expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.\     CopyFromReg is custom lowered by adding a I128toV2I64, which breaks down the 128-bit outputs of inline asm into the expanded values.
1 parent e25e800 commit cbd3f25

15 files changed

+529
-3
lines changed

clang/lib/Basic/Targets/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
105105
case 'l':
106106
case 'f':
107107
case 'd':
108+
case 'q':
108109
Info.setAllowsRegister();
109110
return true;
110111
}

llvm/docs/LangRef.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5381,6 +5381,7 @@ NVPTX:
53815381
- ``c`` or ``h``: A 16-bit integer register.
53825382
- ``r``: A 32-bit integer register.
53835383
- ``l`` or ``N``: A 64-bit integer register.
5384+
- ``q``: A 128-bit integer register.
53845385
- ``f``: A 32-bit float register.
53855386
- ``d``: A 64-bit float register.
53865387

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
6060
case 6:
6161
OS << "%fd";
6262
break;
63+
case 7:
64+
OS << "%rq";
65+
break;
6366
}
6467

6568
unsigned VReg = Reg.id() & 0x0FFFFFFF;

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
315315
Ret = (5 << 28);
316316
} else if (RC == &NVPTX::Float64RegsRegClass) {
317317
Ret = (6 << 28);
318+
} else if (RC == &NVPTX::Int128RegsRegClass) {
319+
Ret = (7 << 28);
318320
} else {
319321
report_fatal_error("Bad register class");
320322
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,20 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
519519
if (tryConstantFP(N))
520520
return;
521521
break;
522+
case ISD::CopyToReg: {
523+
if (N->getOperand(1).getValueType() == MVT::i128) {
524+
SelectV2I64toI128(N);
525+
return;
526+
}
527+
break;
528+
}
529+
case ISD::CopyFromReg: {
530+
if (N->getOperand(1).getValueType() == MVT::i128) {
531+
SelectI128toV2I64(N);
532+
return;
533+
}
534+
break;
535+
}
522536
default:
523537
break;
524538
}
@@ -3798,6 +3812,60 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
37983812
return true;
37993813
}
38003814

3815+
void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
3816+
// Lower a CopyToReg with two 64-bit inputs
3817+
// Dst:i128, lo:i64, hi:i64
3818+
//
3819+
// CopyToReg Dst, lo, hi;
3820+
//
3821+
// ==>
3822+
//
3823+
// tmp = V2I64toI128 {lo, hi};
3824+
// CopyToReg Dst, tmp;
3825+
SDValue Dst = N->getOperand(1);
3826+
SDValue Lo = N->getOperand(2);
3827+
SDValue Hi = N->getOperand(3);
3828+
3829+
SDLoc DL(N);
3830+
SDNode *Mov =
3831+
CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});
3832+
3833+
SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
3834+
NewOps[0] = N->getOperand(0);
3835+
NewOps[1] = Dst;
3836+
NewOps[2] = SDValue(Mov, 0);
3837+
if (N->getNumOperands() == 5)
3838+
NewOps[3] = N->getOperand(4);
3839+
SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL, SmallVector<EVT>(N->values()), NewOps);
3840+
3841+
ReplaceNode(N, NewValue.getNode());
3842+
}
3843+
3844+
void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
3845+
// Lower CopyFromReg from a 128-bit regs to two 64-bit regs
3846+
// Dst:i128, Src:i128
3847+
//
3848+
// {lo, hi} = CopyFromReg Src
3849+
//
3850+
// ==>
3851+
//
3852+
// {lo, hi} = I128toV2I64 Src
3853+
//
3854+
SDValue Ch = N->getOperand(0);
3855+
SDValue Src = N->getOperand(1);
3856+
SDValue Glue = N->getOperand(2);
3857+
SDLoc DL(N);
3858+
3859+
// Add Glue and Ch to the operands and results to avoid break the execution
3860+
// order
3861+
SDNode *Mov = CurDAG->getMachineNode(
3862+
NVPTX::I128toV2I64, DL,
3863+
{MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()},
3864+
{Src, Ch, Glue});
3865+
3866+
ReplaceNode(N, Mov);
3867+
}
3868+
38013869
/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
38023870
/// conversion from \p SrcTy to \p DestTy.
38033871
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7474
bool SelectSETP_F16X2(SDNode *N);
7575
bool SelectSETP_BF16X2(SDNode *N);
7676
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
77-
77+
void SelectV2I64toI128(SDNode *N);
78+
void SelectI128toV2I64(SDNode *N);
7879
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
7980
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
8081
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
859859
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
860860
}
861861

862+
// Custom lowering for inline asm with 128-bit operands
863+
setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
864+
setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);
865+
862866
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
863867
// No FPOW or FREM in PTX.
864868

@@ -2804,6 +2808,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28042808
return LowerVectorArith(Op, DAG);
28052809
case ISD::DYNAMIC_STACKALLOC:
28062810
return LowerDYNAMIC_STACKALLOC(Op, DAG);
2811+
case ISD::CopyToReg:
2812+
return LowerCopyToReg_128(Op, DAG);
28072813
default:
28082814
llvm_unreachable("Custom lowering not defined for operation");
28092815
}
@@ -3094,6 +3100,54 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
30943100
return Result;
30953101
}
30963102

3103+
SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3104+
SelectionDAG &DAG) const {
3105+
// Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3106+
// operand so that it can pass the legalization.
3107+
3108+
assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3109+
"Custom lowering for 128-bit CopyToReg only");
3110+
3111+
SDNode *Node = Op.getNode();
3112+
SDLoc DL(Node);
3113+
3114+
SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3115+
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3116+
DAG.getIntPtrConstant(0, DL));
3117+
SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3118+
DAG.getIntPtrConstant(1, DL));
3119+
3120+
SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
3121+
SmallVector<EVT, 3> ResultsType(Node->values());
3122+
3123+
NewOps[0] = Op->getOperand(0); // Chain
3124+
NewOps[1] = Op->getOperand(1); // Dst Reg
3125+
NewOps[2] = Lo; // Lower 64-bit
3126+
NewOps[3] = Hi; // Higher 64-bit
3127+
if (Op.getNumOperands() == 4)
3128+
NewOps[4] = Op->getOperand(3); // Glue if exists
3129+
3130+
return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3131+
}
3132+
3133+
unsigned NVPTXTargetLowering::getNumRegisters(
3134+
LLVMContext &Context, EVT VT,
3135+
std::optional<MVT> RegisterVT = std::nullopt) const {
3136+
if (VT == MVT::i128 && RegisterVT == MVT::i128)
3137+
return 1;
3138+
return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3139+
}
3140+
3141+
bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3142+
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3143+
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3144+
if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3145+
Parts[0] = Val;
3146+
return true;
3147+
}
3148+
return false;
3149+
}
3150+
30973151
// This creates target external symbol for a function parameter.
30983152
// Name of the symbol is composed from its index and the function name.
30993153
// Negative index corresponds to special parameter (unsized array) used for
@@ -5150,6 +5204,7 @@ NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
51505204
case 'l':
51515205
case 'f':
51525206
case 'd':
5207+
case 'q':
51535208
case '0':
51545209
case 'N':
51555210
return C_RegisterClass;
@@ -5175,6 +5230,12 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
51755230
case 'l':
51765231
case 'N':
51775232
return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
5233+
case 'q': {
5234+
if (STI.getSmVersion() < 70)
5235+
report_fatal_error("Inline asm with 128 bit operands is only "
5236+
"supported for sm_70 and higher!");
5237+
return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
5238+
}
51785239
case 'f':
51795240
return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
51805241
case 'd':
@@ -6261,6 +6322,30 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
62616322
}
62626323
}
62636324

6325+
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
6326+
SmallVectorImpl<SDValue> &Results) {
6327+
// Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
6328+
// result so that it can pass the legalization
6329+
SDLoc DL(N);
6330+
SDValue Chain = N->getOperand(0);
6331+
SDValue Reg = N->getOperand(1);
6332+
SDValue Glue = N->getOperand(2);
6333+
6334+
assert(Reg.getValueType() == MVT::i128 &&
6335+
"Custom lowering for CopyFromReg with 128-bit reg only");
6336+
SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
6337+
N->getValueType(2)};
6338+
SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
6339+
6340+
SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
6341+
SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
6342+
{NewValue.getValue(0), NewValue.getValue(1)});
6343+
6344+
Results.push_back(Pair);
6345+
Results.push_back(NewValue.getValue(2));
6346+
Results.push_back(NewValue.getValue(3));
6347+
}
6348+
62646349
void NVPTXTargetLowering::ReplaceNodeResults(
62656350
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
62666351
switch (N->getOpcode()) {
@@ -6272,6 +6357,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
62726357
case ISD::INTRINSIC_W_CHAIN:
62736358
ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
62746359
return;
6360+
case ISD::CopyFromReg:
6361+
ReplaceCopyFromReg_128(N, DAG, Results);
6362+
return;
62756363
}
62766364
}
62776365

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,14 @@ class NVPTXTargetLowering : public TargetLowering {
640640
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
641641
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
642642

643+
SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
644+
unsigned getNumRegisters(LLVMContext &Context, EVT VT,
645+
std::optional<MVT> RegisterVT) const override;
646+
bool
647+
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
648+
SDValue *Parts, unsigned NumParts, MVT PartVT,
649+
std::optional<CallingConv::ID> CC) const override;
650+
643651
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
644652
SelectionDAG &DAG) const override;
645653
SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
5151
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
5252
Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
5353
: NVPTX::BITCONVERT_64_F2I);
54+
} else if (DestRC == &NVPTX::Int128RegsRegClass) {
55+
Op = NVPTX::IMOV128rr;
5456
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
5557
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
5658
: NVPTX::BITCONVERT_32_I2F);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,6 +2097,8 @@ let IsSimpleMove=1, hasSideEffects=0 in {
20972097
"mov.u32 \t$dst, $sss;", []>;
20982098
def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss),
20992099
"mov.u64 \t$dst, $sss;", []>;
2100+
def IMOV128rr : NVPTXInst<(outs Int128Regs:$dst), (ins Int128Regs:$sss),
2101+
"mov.b128 \t$dst, $sss;", []>;
21002102

21012103
def IMOVB16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$sss),
21022104
"mov.b16 \t$dst, $sss;", []>;
@@ -3545,6 +3547,9 @@ let hasSideEffects = false in {
35453547
def V2I32toI64 : NVPTXInst<(outs Int64Regs:$d),
35463548
(ins Int32Regs:$s1, Int32Regs:$s2),
35473549
"mov.b64 \t$d, {{$s1, $s2}};", []>;
3550+
def V2I64toI128 : NVPTXInst<(outs Int128Regs:$d),
3551+
(ins Int64Regs:$s1, Int64Regs:$s2),
3552+
"mov.b128 \t$d, {{$s1, $s2}};", []>;
35483553
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
35493554
(ins Float32Regs:$s1, Float32Regs:$s2),
35503555
"mov.b64 \t$d, {{$s1, $s2}};", []>;
@@ -3560,6 +3565,9 @@ let hasSideEffects = false in {
35603565
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
35613566
(ins Int64Regs:$s),
35623567
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
3568+
def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
3569+
(ins Int128Regs:$s),
3570+
"mov.b128 \t{{$d1, $d2}}, $s;", []>;
35633571
def F64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
35643572
(ins Float64Regs:$s),
35653573
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
@@ -3629,7 +3637,7 @@ def : Pat<(i32 (ctlz (i32 Int32Regs:$a))), (CLZr32 Int32Regs:$a)>;
36293637
// ptx value to 64 bits to match the ISD node's semantics, unless we know we're
36303638
// truncating back down to 32 bits.
36313639
def : Pat<(i64 (ctlz Int64Regs:$a)), (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>;
3632-
def : Pat<(i32 (trunc (ctlz Int64Regs:$a))), (CLZr64 Int64Regs:$a)>;
3640+
def : Pat<(i32 (trunc (i64 (ctlz Int64Regs:$a)))), (CLZr64 Int64Regs:$a)>;
36333641

36343642
// For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
36353643
// result back to 16-bits if necessary. We also need to subtract 16 because
@@ -3667,7 +3675,7 @@ def : Pat<(i32 (ctpop (i32 Int32Regs:$a))), (POPCr32 Int32Regs:$a)>;
36673675
// pattern that avoids the type conversion if we're truncating the result to
36683676
// i32 anyway.
36693677
def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>;
3670-
def : Pat<(i32 (trunc (ctpop Int64Regs:$a))), (POPCr64 Int64Regs:$a)>;
3678+
def : Pat<(i32 (trunc (i64 (ctpop Int64Regs:$a)))), (POPCr64 Int64Regs:$a)>;
36713679

36723680
// For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
36733681
// If we know that we're storing into an i32, we can avoid the final trunc.

0 commit comments

Comments
 (0)