@@ -859,6 +859,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
859
859
setBF16OperationAction (Op, MVT::v2bf16, Legal, Expand);
860
860
}
861
861
862
+ // Custom lowering for inline asm with 128-bit operands
863
+ setOperationAction (ISD::CopyToReg, MVT::i128 , Custom);
864
+ setOperationAction (ISD::CopyFromReg, MVT::i128 , Custom);
865
+
862
866
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
863
867
// No FPOW or FREM in PTX.
864
868
@@ -2804,6 +2808,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2804
2808
return LowerVectorArith (Op, DAG);
2805
2809
case ISD::DYNAMIC_STACKALLOC:
2806
2810
return LowerDYNAMIC_STACKALLOC (Op, DAG);
2811
+ case ISD::CopyToReg:
2812
+ return LowerCopyToReg_128 (Op, DAG);
2807
2813
default :
2808
2814
llvm_unreachable (" Custom lowering not defined for operation" );
2809
2815
}
@@ -3094,6 +3100,54 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3094
3100
return Result;
3095
3101
}
3096
3102
3103
+ SDValue NVPTXTargetLowering::LowerCopyToReg_128 (SDValue Op,
3104
+ SelectionDAG &DAG) const {
3105
+ // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3106
+ // operand so that it can pass the legalization.
3107
+
3108
+ assert (Op.getOperand (1 ).getValueType () == MVT::i128 &&
3109
+ " Custom lowering for 128-bit CopyToReg only" );
3110
+
3111
+ SDNode *Node = Op.getNode ();
3112
+ SDLoc DL (Node);
3113
+
3114
+ SDValue Cast = DAG.getBitcast (MVT::v2i64, Op->getOperand (2 ));
3115
+ SDValue Lo = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64 , Cast,
3116
+ DAG.getIntPtrConstant (0 , DL));
3117
+ SDValue Hi = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64 , Cast,
3118
+ DAG.getIntPtrConstant (1 , DL));
3119
+
3120
+ SmallVector<SDValue, 5 > NewOps (Op->getNumOperands () + 1 );
3121
+ SmallVector<EVT, 3 > ResultsType (Node->values ());
3122
+
3123
+ NewOps[0 ] = Op->getOperand (0 ); // Chain
3124
+ NewOps[1 ] = Op->getOperand (1 ); // Dst Reg
3125
+ NewOps[2 ] = Lo; // Lower 64-bit
3126
+ NewOps[3 ] = Hi; // Higher 64-bit
3127
+ if (Op.getNumOperands () == 4 )
3128
+ NewOps[4 ] = Op->getOperand (3 ); // Glue if exists
3129
+
3130
+ return DAG.getNode (ISD::CopyToReg, DL, ResultsType, NewOps);
3131
+ }
3132
+
3133
+ unsigned NVPTXTargetLowering::getNumRegisters (
3134
+ LLVMContext &Context, EVT VT,
3135
+ std::optional<MVT> RegisterVT = std::nullopt ) const {
3136
+ if (VT == MVT::i128 && RegisterVT == MVT::i128 )
3137
+ return 1 ;
3138
+ return TargetLoweringBase::getNumRegisters (Context, VT, RegisterVT);
3139
+ }
3140
+
3141
+ bool NVPTXTargetLowering::splitValueIntoRegisterParts (
3142
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3143
+ unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3144
+ if (Val.getValueType () == MVT::i128 && NumParts == 1 ) {
3145
+ Parts[0 ] = Val;
3146
+ return true ;
3147
+ }
3148
+ return false ;
3149
+ }
3150
+
3097
3151
// This creates target external symbol for a function parameter.
3098
3152
// Name of the symbol is composed from its index and the function name.
3099
3153
// Negative index corresponds to special parameter (unsized array) used for
@@ -5150,6 +5204,7 @@ NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5150
5204
case ' l' :
5151
5205
case ' f' :
5152
5206
case ' d' :
5207
+ case ' q' :
5153
5208
case ' 0' :
5154
5209
case ' N' :
5155
5210
return C_RegisterClass;
@@ -5175,6 +5230,12 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5175
5230
case ' l' :
5176
5231
case ' N' :
5177
5232
return std::make_pair (0U , &NVPTX::Int64RegsRegClass);
5233
+ case ' q' : {
5234
+ if (STI.getSmVersion () < 70 )
5235
+ report_fatal_error (" Inline asm with 128 bit operands is only "
5236
+ " supported for sm_70 and higher!" );
5237
+ return std::make_pair (0U , &NVPTX::Int128RegsRegClass);
5238
+ }
5178
5239
case ' f' :
5179
5240
return std::make_pair (0U , &NVPTX::Float32RegsRegClass);
5180
5241
case ' d' :
@@ -6261,6 +6322,30 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
6261
6322
}
6262
6323
}
6263
6324
6325
+ static void ReplaceCopyFromReg_128 (SDNode *N, SelectionDAG &DAG,
6326
+ SmallVectorImpl<SDValue> &Results) {
6327
+ // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
6328
+ // result so that it can pass the legalization
6329
+ SDLoc DL (N);
6330
+ SDValue Chain = N->getOperand (0 );
6331
+ SDValue Reg = N->getOperand (1 );
6332
+ SDValue Glue = N->getOperand (2 );
6333
+
6334
+ assert (Reg.getValueType () == MVT::i128 &&
6335
+ " Custom lowering for CopyFromReg with 128-bit reg only" );
6336
+ SmallVector<EVT, 4 > ResultsType = {MVT::i64 , MVT::i64 , N->getValueType (1 ),
6337
+ N->getValueType (2 )};
6338
+ SmallVector<SDValue, 3 > NewOps = {Chain, Reg, Glue};
6339
+
6340
+ SDValue NewValue = DAG.getNode (ISD::CopyFromReg, DL, ResultsType, NewOps);
6341
+ SDValue Pair = DAG.getNode (ISD::BUILD_PAIR, DL, MVT::i128 ,
6342
+ {NewValue.getValue (0 ), NewValue.getValue (1 )});
6343
+
6344
+ Results.push_back (Pair);
6345
+ Results.push_back (NewValue.getValue (2 ));
6346
+ Results.push_back (NewValue.getValue (3 ));
6347
+ }
6348
+
6264
6349
void NVPTXTargetLowering::ReplaceNodeResults (
6265
6350
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
6266
6351
switch (N->getOpcode ()) {
@@ -6272,6 +6357,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
6272
6357
case ISD::INTRINSIC_W_CHAIN:
6273
6358
ReplaceINTRINSIC_W_CHAIN (N, DAG, Results);
6274
6359
return ;
6360
+ case ISD::CopyFromReg:
6361
+ ReplaceCopyFromReg_128 (N, DAG, Results);
6362
+ return ;
6275
6363
}
6276
6364
}
6277
6365
0 commit comments