Skip to content

Commit c5dadbc

Browse files
committed
-- Added legalizer rules for the @llvm.lround.* and @llvm.llround.* intrinsics
-- Added instruction selection for the @llvm.lround.* and @llvm.llround.* intrinsics
-- Added tests for the @llvm.lround.* and @llvm.llround.* intrinsics
-- Corrected formatting for the files in the PR
-- Corrections in the test files
1 parent 89e7f4d commit c5dadbc

File tree

4 files changed

+278
-1
lines changed

4 files changed

+278
-1
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

+62-1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,12 @@ class SPIRVInstructionSelector : public InstructionSelector {
282282
GL::GLSLExtInst GLInst) const;
283283
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
284284
MachineInstr &I, const ExtInstList &ExtInsts) const;
285+
bool selectExtInstForLRound(Register ResVReg, const SPIRVType *ResType,
286+
MachineInstr &I, CL::OpenCLExtInst CLInst,
287+
GL::GLSLExtInst GLInst) const;
288+
bool selectExtInstForLRound(Register ResVReg, const SPIRVType *ResType,
289+
MachineInstr &I,
290+
const ExtInstList &ExtInsts) const;
285291

286292
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
287293
MachineInstr &I) const;
@@ -622,7 +628,26 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
622628
return selectSUCmp(ResVReg, ResType, I, true);
623629
case TargetOpcode::G_UCMP:
624630
return selectSUCmp(ResVReg, ResType, I, false);
625-
631+
case TargetOpcode::G_LROUND:
case TargetOpcode::G_LLROUND: {
  // lround/llround are selected in two steps: an extended-instruction
  // `round` producing a value of the source operand's type, followed by an
  // OpConvertFToS into the requested result type.
  //
  // The temporary register holds the rounded value, so it is assigned the
  // SPIR-V type of the source operand (operand 1) rather than ResType.
  // NOTE(review): the temporary carries a floating-point value but is placed
  // in iIDRegClass; confirm this is the intended register class.
  Register RoundedReg =
      MRI->createVirtualRegister(MRI->getRegClass(ResVReg), "lround");
  MRI->setRegClass(RoundedReg, &SPIRV::iIDRegClass);
  GR.assignSPIRVTypeToVReg(GR.getSPIRVTypeForVReg(I.getOperand(1).getReg()),
                           RoundedReg, *(I.getParent()->getParent()));

  // Fail explicitly when no extended instruction set can provide `round`.
  // Without this return, control would fall through into the
  // G_STRICT_FMA/G_FMA case that follows and select an fma instead.
  if (!selectExtInstForLRound(RoundedReg, GR.getSPIRVTypeForVReg(RoundedReg),
                              I, CL::round, GL::Round))
    return false;

  // Convert the rounded floating-point value to the signed integer result.
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConvertFToS))
                 .addDef(ResVReg)
                 .addUse(GR.getSPIRVTypeID(ResType))
                 .addUse(RoundedReg);
  return MIB.constrainAllUses(TII, TRI, RBI);
}
626651
case TargetOpcode::G_STRICT_FMA:
627652
case TargetOpcode::G_FMA:
628653
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
@@ -961,6 +986,42 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
961986
}
962987
return false;
963988
}
989+
/// Convenience overload: packages the OpenCL and GLSL variants of an extended
/// instruction into a candidate list (OpenCL first, GLSL second) and forwards
/// to the list-based selectExtInstForLRound overload.
///
/// \returns whatever the list-based overload returns.
bool SPIRVInstructionSelector::selectExtInstForLRound(
    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
    CL::OpenCLExtInst CLInst, GL::GLSLExtInst GLInst) const {
  // Order matters: the list-based overload picks the first usable entry.
  const ExtInstList Candidates = {
      {SPIRV::InstructionSet::OpenCL_std, CLInst},
      {SPIRV::InstructionSet::GLSL_std_450, GLInst},
  };
  return selectExtInstForLRound(ResVReg, ResType, I, Candidates);
}
996+
997+
/// Emits an OpExtInst defining \p ResVReg with type \p ResType, using the
/// first entry of \p Insts whose instruction set the subtarget can use.
///
/// The operands of \p I, starting after the def (and after the intrinsic ID
/// operand when operand 1 is one), are appended to the new instruction.
///
/// \returns true if an instruction was emitted, false if none of the
/// instruction sets in \p Insts is usable.
bool SPIRVInstructionSelector::selectExtInstForLRound(
    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
    const ExtInstList &Insts) const {
  for (const auto &Candidate : Insts) {
    const SPIRV::InstructionSet::InstructionSet ExtSet = Candidate.first;
    if (!STI.canUseExtInstSet(ExtSet))
      continue;

    const uint32_t ExtOpcode = Candidate.second;
    MachineBasicBlock &MBB = *I.getParent();
    auto Builder = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
                       .addDef(ResVReg)
                       .addUse(GR.getSPIRVTypeID(ResType))
                       .addImm(static_cast<uint32_t>(ExtSet))
                       .addImm(ExtOpcode);

    // Operand 0 is skipped; operand 1 is skipped too when it is an
    // intrinsic ID rather than a real source operand.
    const unsigned OperandCount = I.getNumOperands();
    const bool HasIntrinsicID =
        OperandCount > 1 &&
        I.getOperand(1).getType() ==
            MachineOperand::MachineOperandType::MO_IntrinsicID;
    for (unsigned OpIdx = HasIntrinsicID ? 2 : 1; OpIdx < OperandCount;
         ++OpIdx)
      Builder.add(I.getOperand(OpIdx));

    Builder.constrainAllUses(TII, TRI, RBI);
    return true;
  }
  return false;
}
9641025

9651026
bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
9661027
const SPIRVType *ResType,

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
305305
{G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
306306
.alwaysLegal();
307307

308+
getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
309+
.legalForCartesianProduct(allFloatScalarsAndVectors,
310+
allIntScalarsAndVectors);
311+
308312
// FP conversions.
309313
getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
310314
.legalForCartesianProduct(allFloatScalarsAndVectors);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
5+
6+
; CHECK: [[opencl:%[0-9]+]] = OpExtInstImport "OpenCL.std"
7+
; CHECK-DAG: [[f32:%[0-9]+]] = OpTypeFloat 32
8+
; CHECK-DAG: [[i32:%[0-9]+]] = OpTypeInt 32 0
9+
; CHECK-DAG: [[f64:%[0-9]+]] = OpTypeFloat 64
10+
; CHECK-DAG: [[i64:%[0-9]+]] = OpTypeInt 64 0
11+
; CHECK-DAG: [[vecf32:%[0-9]+]] = OpTypeVector [[f32]]
12+
; CHECK-DAG: [[veci32:%[0-9]+]] = OpTypeVector [[i32]]
13+
; CHECK-DAG: [[vecf64:%[0-9]+]] = OpTypeVector [[f64]]
14+
; CHECK-DAG: [[veci64:%[0-9]+]] = OpTypeVector [[i64]]
15+
16+
; CHECK: [[rounded_i32_f32:%[0-9]+]] = OpExtInst [[f32]] [[opencl]] round %[[#]]
17+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i32]] [[rounded_i32_f32]]
18+
; CHECK: [[rounded_i32_f64:%[0-9]+]] = OpExtInst [[f64]] [[opencl]] round %[[#]]
19+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i32]] [[rounded_i32_f64]]
20+
; CHECK: [[rounded_i64_f32:%[0-9]+]] = OpExtInst [[f32]] [[opencl]] round %[[#]]
21+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i64]] [[rounded_i64_f32]]
22+
; CHECK: [[rounded_i64_f64:%[0-9]+]] = OpExtInst [[f64]] [[opencl]] round %[[#]]
23+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i64]] [[rounded_i64_f64]]
24+
; CHECK: [[rounded_v4i32_f32:%[0-9]+]] = OpExtInst [[vecf32]] [[opencl]] round %[[#]]
25+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci32]] [[rounded_v4i32_f32]]
26+
; CHECK: [[rounded_v4i32_f64:%[0-9]+]] = OpExtInst [[vecf64]] [[opencl]] round %[[#]]
27+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci32]] [[rounded_v4i32_f64]]
28+
; CHECK: [[rounded_v4i64_f32:%[0-9]+]] = OpExtInst [[vecf32]] [[opencl]] round %[[#]]
29+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci64]] [[rounded_v4i64_f32]]
30+
; CHECK: [[rounded_v4i64_f64:%[0-9]+]] = OpExtInst [[vecf64]] [[opencl]] round %[[#]]
31+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci64]] [[rounded_v4i64_f64]]
32+
33+
34+
define spir_func i32 @test_llround_i32_f32(float %arg0) {
35+
entry:
36+
37+
%0 = call i32 @llvm.llround.i32.f32(float %arg0)
38+
ret i32 %0
39+
}
40+
41+
define spir_func i32 @test_llround_i32_f64(double %arg0) {
42+
entry:
43+
44+
45+
%0 = call i32 @llvm.llround.i32.f64(double %arg0)
46+
ret i32 %0
47+
}
48+
49+
define spir_func i64 @test_llround_i64_f32(float %arg0) {
50+
entry:
51+
52+
53+
%0 = call i64 @llvm.llround.i64.f32(float %arg0)
54+
ret i64 %0
55+
}
56+
57+
define spir_func i64 @test_llround_i64_f64(double %arg0) {
58+
entry:
59+
60+
61+
%0 = call i64 @llvm.llround.i64.f64(double %arg0)
62+
ret i64 %0
63+
}
64+
65+
define spir_func <4 x i32> @test_llround_v4i32_f32(<4 x float> %arg0) {
66+
entry:
67+
68+
69+
%0 = call <4 x i32> @llvm.llround.v4i32.f32(<4 x float> %arg0)
70+
ret <4 x i32> %0
71+
}
72+
73+
74+
define spir_func <4 x i32> @test_llround_v4i32_f64(<4 x double> %arg0) {
75+
entry:
76+
77+
78+
%0 = call <4 x i32> @llvm.llround.v4i32.f64(<4 x double> %arg0)
79+
ret <4 x i32> %0
80+
}
81+
82+
define spir_func <4 x i64> @test_llround_v4i64_f32(<4 x float> %arg0) {
83+
entry:
84+
85+
86+
%0 = call <4 x i64> @llvm.llround.v4i64.f32(<4 x float> %arg0)
87+
ret <4 x i64> %0
88+
}
89+
90+
91+
define spir_func <4 x i64> @test_llround_v4i64_f64(<4 x double> %arg0) {
92+
entry:
93+
94+
%0 = call <4 x i64> @llvm.llround.v4i64.f64(<4 x double> %arg0)
95+
ret <4 x i64> %0
96+
}
97+
98+
declare i32 @llvm.llround.i32.f32(float)
99+
declare i32 @llvm.llround.i32.f64(double)
100+
declare i64 @llvm.llround.i64.f32(float)
101+
declare i64 @llvm.llround.i64.f64(double)
102+
103+
declare <4 x i32> @llvm.llround.v4i32.f32(<4 x float>)
104+
declare <4 x i32> @llvm.llround.v4i32.f64(<4 x double>)
105+
declare <4 x i64> @llvm.llround.v4i64.f32(<4 x float>)
106+
declare <4 x i64> @llvm.llround.v4i64.f64(<4 x double>)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
5+
6+
; CHECK: [[opencl:%[0-9]+]] = OpExtInstImport "OpenCL.std"
7+
; CHECK-DAG: [[f32:%[0-9]+]] = OpTypeFloat 32
8+
; CHECK-DAG: [[i32:%[0-9]+]] = OpTypeInt 32 0
9+
; CHECK-DAG: [[f64:%[0-9]+]] = OpTypeFloat 64
10+
; CHECK-DAG: [[i64:%[0-9]+]] = OpTypeInt 64 0
11+
; CHECK-DAG: [[vecf32:%[0-9]+]] = OpTypeVector [[f32]]
12+
; CHECK-DAG: [[veci32:%[0-9]+]] = OpTypeVector [[i32]]
13+
; CHECK-DAG: [[vecf64:%[0-9]+]] = OpTypeVector [[f64]]
14+
; CHECK-DAG: [[veci64:%[0-9]+]] = OpTypeVector [[i64]]
15+
16+
; CHECK: [[rounded_i32_f32:%[0-9]+]] = OpExtInst [[f32]] [[opencl]] round %[[#]]
17+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i32]] [[rounded_i32_f32]]
18+
; CHECK: [[rounded_i32_f64:%[0-9]+]] = OpExtInst [[f64]] [[opencl]] round %[[#]]
19+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i32]] [[rounded_i32_f64]]
20+
; CHECK: [[rounded_i64_f32:%[0-9]+]] = OpExtInst [[f32]] [[opencl]] round %[[#]]
21+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i64]] [[rounded_i64_f32]]
22+
; CHECK: [[rounded_i64_f64:%[0-9]+]] = OpExtInst [[f64]] [[opencl]] round %[[#]]
23+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[i64]] [[rounded_i64_f64]]
24+
; CHECK: [[rounded_v4i32_f32:%[0-9]+]] = OpExtInst [[vecf32]] [[opencl]] round %[[#]]
25+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci32]] [[rounded_v4i32_f32]]
26+
; CHECK: [[rounded_v4i32_f64:%[0-9]+]] = OpExtInst [[vecf64]] [[opencl]] round %[[#]]
27+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci32]] [[rounded_v4i32_f64]]
28+
; CHECK: [[rounded_v4i64_f32:%[0-9]+]] = OpExtInst [[vecf32]] [[opencl]] round %[[#]]
29+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci64]] [[rounded_v4i64_f32]]
30+
; CHECK: [[rounded_v4i64_f64:%[0-9]+]] = OpExtInst [[vecf64]] [[opencl]] round %[[#]]
31+
; CHECK-NEXT: %[[#]] = OpConvertFToS [[veci64]] [[rounded_v4i64_f64]]
32+
33+
34+
define spir_func i32 @test_lround_i32_f32(float %arg0) {
35+
entry:
36+
37+
%0 = call i32 @llvm.lround.i32.f32(float %arg0)
38+
ret i32 %0
39+
}
40+
41+
define spir_func i32 @test_lround_i32_f64(double %arg0) {
42+
entry:
43+
44+
45+
%0 = call i32 @llvm.lround.i32.f64(double %arg0)
46+
ret i32 %0
47+
}
48+
49+
define spir_func i64 @test_lround_i64_f32(float %arg0) {
50+
entry:
51+
52+
53+
%0 = call i64 @llvm.lround.i64.f32(float %arg0)
54+
ret i64 %0
55+
}
56+
57+
define spir_func i64 @test_lround_i64_f64(double %arg0) {
58+
entry:
59+
60+
61+
%0 = call i64 @llvm.lround.i64.f64(double %arg0)
62+
ret i64 %0
63+
}
64+
65+
define spir_func <4 x i32> @test_lround_v4i32_f32(<4 x float> %arg0) {
66+
entry:
67+
68+
69+
%0 = call <4 x i32> @llvm.lround.v4i32.f32(<4 x float> %arg0)
70+
ret <4 x i32> %0
71+
}
72+
73+
74+
define spir_func <4 x i32> @test_lround_v4i32_f64(<4 x double> %arg0) {
75+
entry:
76+
77+
78+
%0 = call <4 x i32> @llvm.lround.v4i32.f64(<4 x double> %arg0)
79+
ret <4 x i32> %0
80+
}
81+
82+
define spir_func <4 x i64> @test_lround_v4i64_f32(<4 x float> %arg0) {
83+
entry:
84+
85+
86+
%0 = call <4 x i64> @llvm.lround.v4i64.f32(<4 x float> %arg0)
87+
ret <4 x i64> %0
88+
}
89+
90+
91+
define spir_func <4 x i64> @test_lround_v4i64_f64(<4 x double> %arg0) {
92+
entry:
93+
94+
%0 = call <4 x i64> @llvm.lround.v4i64.f64(<4 x double> %arg0)
95+
ret <4 x i64> %0
96+
}
97+
98+
declare i32 @llvm.lround.i32.f32(float)
99+
declare i32 @llvm.lround.i32.f64(double)
100+
declare i64 @llvm.lround.i64.f32(float)
101+
declare i64 @llvm.lround.i64.f64(double)
102+
103+
declare <4 x i32> @llvm.lround.v4i32.f32(<4 x float>)
104+
declare <4 x i32> @llvm.lround.v4i32.f64(<4 x double>)
105+
declare <4 x i64> @llvm.lround.v4i64.f32(<4 x float>)
106+
declare <4 x i64> @llvm.lround.v4i64.f64(<4 x double>)

0 commit comments

Comments
 (0)