diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index 20834c5476468..3151d69ab745d 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -19,7 +19,7 @@ class SPIRVSubtarget; class InstructionSelector; class RegisterBankInfo; -ModulePass *createSPIRVPrepareFunctionsPass(); +ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM); FunctionPass *createSPIRVRegularizerPass(); FunctionPass *createSPIRVPreLegalizerPass(); FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 035989f2fe571..da61af7a669f1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "MCTargetDesc/SPIRVBaseInfo.h" #include "MCTargetDesc/SPIRVMCTargetDesc.h" #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" @@ -1407,15 +1408,17 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_alloca: return selectFrameIndex(ResVReg, ResType, I); case Intrinsic::spv_assume: - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAssumeTrueKHR)) - .addUse(I.getOperand(1).getReg()); + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAssumeTrueKHR)) + .addUse(I.getOperand(1).getReg()); break; case Intrinsic::spv_expect: - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExpectKHR)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(I.getOperand(2).getReg()) - .addUse(I.getOperand(3).getReg()); + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExpectKHR)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()); break; default: llvm_unreachable("Intrinsic selection not implemented"); diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index 87a9a0e4fab84..c376497469ce3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -19,6 +19,7 @@ //===----------------------------------------------------------------------===// #include "SPIRV.h" +#include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" #include "llvm/CodeGen/IntrinsicLowering.h" @@ -38,12 +39,13 @@ void initializeSPIRVPrepareFunctionsPass(PassRegistry &); namespace { class SPIRVPrepareFunctions : public ModulePass { + const SPIRVTargetMachine &TM; bool substituteIntrinsicCalls(Function *F); Function *removeAggregateTypesFromSignature(Function *F); public: static char ID; - SPIRVPrepareFunctions() : ModulePass(ID) { + SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); } @@ -300,7 +302,9 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { Changed = true; } else if (II->getIntrinsicID() == Intrinsic::assume || II->getIntrinsicID() == Intrinsic::expect) { - lowerExpectAssume(II); + const SPIRVSubtarget &STI = TM.getSubtarget(*F); + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) + lowerExpectAssume(II); Changed = true; } } @@ -394,6 +398,7 @@ bool SPIRVPrepareFunctions::runOnModule(Module &M) { return Changed; } -ModulePass *llvm::createSPIRVPrepareFunctionsPass() { - return new SPIRVPrepareFunctions(); +ModulePass * +llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) { + return new SPIRVPrepareFunctions(TM); } diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp index 0c185f663b63f..cf6dfb127cdeb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp @@ -41,6 +41,10 @@ cl::list Extensions( "SPV_KHR_no_integer_wrap_decoration", "Adds decorations to indicate that a given instruction does " "not cause integer wrapping"), + clEnumValN(SPIRV::Extension::SPV_KHR_expect_assume, + "SPV_KHR_expect_assume", + "Provides additional information to a compiler, similar to " + "the llvm.assume and llvm.expect intrinsics."), clEnumValN(SPIRV::Extension::SPV_KHR_bit_instructions, "SPV_KHR_bit_instructions", "This enables bit instructions to be used by SPIR-V modules " diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index 14dd429b45191..1503f263e42c0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -90,7 +90,7 @@ namespace { class SPIRVPassConfig : public TargetPassConfig { public: SPIRVPassConfig(SPIRVTargetMachine &TM, PassManagerBase &PM) - : TargetPassConfig(TM, PM) {} + : TargetPassConfig(TM, PM), TM(TM) {} SPIRVTargetMachine &getSPIRVTargetMachine() const { return getTM(); @@ -109,6 +109,9 @@ class SPIRVPassConfig : public TargetPassConfig { void addOptimizedRegAlloc() override {} void addPostRegAlloc() override; + +private: + const SPIRVTargetMachine &TM; }; } // namespace @@ -150,7 +153,7 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) { void SPIRVPassConfig::addIRPasses() { TargetPassConfig::addIRPasses(); addPass(createSPIRVRegularizerPass()); - addPass(createSPIRVPrepareFunctionsPass()); + addPass(createSPIRVPrepareFunctionsPass(TM)); } void SPIRVPassConfig::addISelPrepare() { diff --git a/llvm/test/CodeGen/SPIRV/assume.ll b/llvm/test/CodeGen/SPIRV/assume.ll index 679db5d88d4fb..6099955e4afb4 100644 --- a/llvm/test/CodeGen/SPIRV/assume.ll +++ b/llvm/test/CodeGen/SPIRV/assume.ll @@ -1,15 +1,20 @@ -; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck %s -; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck %s +; RUN: llc -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_KHR_expect_assume < %s | FileCheck --check-prefixes=EXT,CHECK %s +; RUN: llc -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_KHR_expect_assume < %s | FileCheck --check-prefixes=EXT,CHECK %s +; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck --check-prefixes=NOEXT,CHECK %s +; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefixes=NOEXT,CHECK %s -; CHECK: OpCapability ExpectAssumeKHR -; CHECK-NEXT: OpExtension "SPV_KHR_expect_assume" +; EXT: OpCapability ExpectAssumeKHR +; EXT-NEXT: OpExtension "SPV_KHR_expect_assume" +; NOEXT-NOT: OpCapability ExpectAssumeKHR +; NOEXT-NOT: OpExtension "SPV_KHR_expect_assume" declare void @llvm.assume(i1) -; CHECK-DAG: %9 = OpIEqual %5 %6 %7 -; CHECK-NEXT: OpAssumeTrueKHR %9 -define void @assumeeq(i32 %x, i32 %y) { +; CHECK-DAG: %8 = OpIEqual %3 %5 %6 +; EXT: OpAssumeTrueKHR %8 +; NOEXT-NOT: OpAssumeTrueKHR %8 +define i1 @assumeeq(i32 %x, i32 %y) { %cmp = icmp eq i32 %x, %y call void @llvm.assume(i1 %cmp) - ret void + ret i1 %cmp } diff --git a/llvm/test/CodeGen/SPIRV/expect.ll b/llvm/test/CodeGen/SPIRV/expect.ll index 530ba7e5a49b0..51555cd155523 100644 --- a/llvm/test/CodeGen/SPIRV/expect.ll +++ b/llvm/test/CodeGen/SPIRV/expect.ll @@ -1,8 +1,12 @@ -; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck %s -; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck %s +; RUN: llc -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_KHR_expect_assume < %s | FileCheck --check-prefixes=CHECK,EXT %s +; RUN: llc -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_KHR_expect_assume < %s | FileCheck --check-prefixes=CHECK,EXT %s +; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck --check-prefixes=CHECK,NOEXT %s +; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefixes=CHECK,NOEXT %s -; CHECK: OpCapability ExpectAssumeKHR -; CHECK-NEXT: OpExtension "SPV_KHR_expect_assume" +; EXT: OpCapability ExpectAssumeKHR +; EXT-NEXT: OpExtension "SPV_KHR_expect_assume" +; NOEXT-NOT: OpCapability ExpectAssumeKHR +; NOEXT-NOT: OpExtension "SPV_KHR_expect_assume" declare i32 @llvm.expect.i32(i32, i32) declare i32 @getOne() @@ -10,7 +14,8 @@ declare i32 @getOne() ; CHECK-DAG: %2 = OpTypeInt 32 0 ; CHECK-DAG: %6 = OpFunctionParameter %2 ; CHECK-DAG: %9 = OpIMul %2 %6 %8 -; CHECK-DAG: %10 = OpExpectKHR %2 %9 %6 +; EXT-DAG: %10 = OpExpectKHR %2 %9 %6 +; NOEXT-NOT: %10 = OpExpectKHR %2 %9 %6 define i32 @test(i32 %x) { %one = call i32 @getOne()