Skip to content

Commit e08f1fd

Browse files
authored
[clang][SPIR-V] Always add convergence intrinsics (#88918)
PR #80680 added logic in the codegen to lazily add convergence intrinsics when required. This logic relied on the LoopStack. The issue is that, while the loop condition is being processed, the LoopStack does not yet reflect the correct values; this is expected, since we are not yet in the loop. However, convergence tokens should sometimes already be available at that point. The solution that seemed simplest is to greedily generate the tokens whenever we generate SPIR-V. Fixes #88144 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent a4accdf commit e08f1fd

File tree

12 files changed

+580
-149
lines changed

12 files changed

+580
-149
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

+1-87
Original file line numberDiff line numberDiff line change
@@ -1141,91 +1141,8 @@ struct BitTest {
11411141
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
11421142
};
11431143

1144-
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
1145-
// Returns nullptr otherwise.
1146-
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
1147-
for (auto &I : *BB) {
1148-
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
1149-
if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
1150-
return II;
1151-
}
1152-
return nullptr;
1153-
}
1154-
11551144
} // namespace
11561145

1157-
llvm::CallBase *
1158-
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
1159-
llvm::Value *ParentToken) {
1160-
llvm::Value *bundleArgs[] = {ParentToken};
1161-
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
1162-
auto Output = llvm::CallBase::addOperandBundle(
1163-
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
1164-
Input->replaceAllUsesWith(Output);
1165-
Input->eraseFromParent();
1166-
return Output;
1167-
}
1168-
1169-
llvm::IntrinsicInst *
1170-
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
1171-
llvm::Value *ParentToken) {
1172-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1173-
Builder.SetInsertPoint(&BB->front());
1174-
auto CB = Builder.CreateIntrinsic(
1175-
llvm::Intrinsic::experimental_convergence_loop, {}, {});
1176-
Builder.restoreIP(IP);
1177-
1178-
auto I = addConvergenceControlToken(CB, ParentToken);
1179-
return cast<llvm::IntrinsicInst>(I);
1180-
}
1181-
1182-
llvm::IntrinsicInst *
1183-
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
1184-
auto *BB = &F->getEntryBlock();
1185-
auto *token = getConvergenceToken(BB);
1186-
if (token)
1187-
return token;
1188-
1189-
// Adding a convergence token requires the function to be marked as
1190-
// convergent.
1191-
F->setConvergent();
1192-
1193-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1194-
Builder.SetInsertPoint(&BB->front());
1195-
auto I = Builder.CreateIntrinsic(
1196-
llvm::Intrinsic::experimental_convergence_entry, {}, {});
1197-
assert(isa<llvm::IntrinsicInst>(I));
1198-
Builder.restoreIP(IP);
1199-
1200-
return cast<llvm::IntrinsicInst>(I);
1201-
}
1202-
1203-
llvm::IntrinsicInst *
1204-
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
1205-
assert(LI != nullptr);
1206-
1207-
auto *token = getConvergenceToken(LI->getHeader());
1208-
if (token)
1209-
return token;
1210-
1211-
llvm::IntrinsicInst *PII =
1212-
LI->getParent()
1213-
? emitConvergenceLoopToken(
1214-
LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
1215-
: getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
1216-
1217-
return emitConvergenceLoopToken(LI->getHeader(), PII);
1218-
}
1219-
1220-
llvm::CallBase *
1221-
CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
1222-
llvm::Value *ParentToken =
1223-
LoopStack.hasInfo()
1224-
? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
1225-
: getOrEmitConvergenceEntryToken(Input->getFunction());
1226-
return addConvergenceControlToken(Input, ParentToken);
1227-
}
1228-
12291146
BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
12301147
switch (BuiltinID) {
12311148
// Main portable variants.
@@ -18402,12 +18319,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1840218319
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
1840318320
}
1840418321
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
18405-
auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
18322+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1840618323
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
1840718324
{}, false, true));
18408-
if (getTarget().getTriple().isSPIRVLogical())
18409-
CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
18410-
return CI;
1841118325
}
1841218326
}
1841318327
return nullptr;

clang/lib/CodeGen/CGCall.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -4830,6 +4830,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
48304830
llvm::CallInst *call = Builder.CreateCall(
48314831
callee, args, getBundlesForFunclet(callee.getCallee()), name);
48324832
call->setCallingConv(getRuntimeCC());
4833+
4834+
if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
4835+
return addControlledConvergenceToken(call);
48334836
return call;
48344837
}
48354838

@@ -5730,7 +5733,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
57305733
if (!CI->getType()->isVoidTy())
57315734
CI->setName("call");
57325735

5733-
if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
5736+
if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
57345737
CI = addControlledConvergenceToken(CI);
57355738

57365739
// Update largest vector width from the return type.

clang/lib/CodeGen/CGStmt.cpp

+93
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
978978
JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
979979
EmitBlock(LoopHeader.getBlock());
980980

981+
if (CGM.shouldEmitConvergenceTokens())
982+
ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
983+
LoopHeader.getBlock(), ConvergenceTokenStack.back()));
984+
981985
// Create an exit block for when the condition fails, which will
982986
// also become the break target.
983987
JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1079,6 +1083,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
10791083
// block.
10801084
if (llvm::EnableSingleByteCoverage)
10811085
incrementProfileCounter(&S);
1086+
1087+
if (CGM.shouldEmitConvergenceTokens())
1088+
ConvergenceTokenStack.pop_back();
10821089
}
10831090

10841091
void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1098,6 +1105,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
10981105
EmitBlockWithFallThrough(LoopBody, S.getBody());
10991106
else
11001107
EmitBlockWithFallThrough(LoopBody, &S);
1108+
1109+
if (CGM.shouldEmitConvergenceTokens())
1110+
ConvergenceTokenStack.push_back(
1111+
emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
1112+
11011113
{
11021114
RunCleanupsScope BodyScope(*this);
11031115
EmitStmt(S.getBody());
@@ -1151,6 +1163,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
11511163
// block.
11521164
if (llvm::EnableSingleByteCoverage)
11531165
incrementProfileCounter(&S);
1166+
1167+
if (CGM.shouldEmitConvergenceTokens())
1168+
ConvergenceTokenStack.pop_back();
11541169
}
11551170

11561171
void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1170,6 +1185,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
11701185
llvm::BasicBlock *CondBlock = CondDest.getBlock();
11711186
EmitBlock(CondBlock);
11721187

1188+
if (CGM.shouldEmitConvergenceTokens())
1189+
ConvergenceTokenStack.push_back(
1190+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1191+
11731192
const SourceRange &R = S.getSourceRange();
11741193
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
11751194
SourceLocToDebugLoc(R.getBegin()),
@@ -1279,6 +1298,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
12791298
// block.
12801299
if (llvm::EnableSingleByteCoverage)
12811300
incrementProfileCounter(&S);
1301+
1302+
if (CGM.shouldEmitConvergenceTokens())
1303+
ConvergenceTokenStack.pop_back();
12821304
}
12831305

12841306
void
@@ -1301,6 +1323,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
13011323
llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
13021324
EmitBlock(CondBlock);
13031325

1326+
if (CGM.shouldEmitConvergenceTokens())
1327+
ConvergenceTokenStack.push_back(
1328+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1329+
13041330
const SourceRange &R = S.getSourceRange();
13051331
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
13061332
SourceLocToDebugLoc(R.getBegin()),
@@ -1369,6 +1395,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
13691395
// block.
13701396
if (llvm::EnableSingleByteCoverage)
13711397
incrementProfileCounter(&S);
1398+
1399+
if (CGM.shouldEmitConvergenceTokens())
1400+
ConvergenceTokenStack.pop_back();
13721401
}
13731402

13741403
void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3158,3 +3187,67 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
31583187

31593188
return F;
31603189
}
3190+
3191+
namespace {
3192+
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
3193+
// Returns nullptr otherwise.
3194+
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
3195+
for (auto &I : *BB) {
3196+
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
3197+
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
3198+
return II;
3199+
}
3200+
return nullptr;
3201+
}
3202+
3203+
} // namespace
3204+
3205+
llvm::CallBase *
3206+
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
3207+
llvm::Value *ParentToken) {
3208+
llvm::Value *bundleArgs[] = {ParentToken};
3209+
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
3210+
auto Output = llvm::CallBase::addOperandBundle(
3211+
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
3212+
Input->replaceAllUsesWith(Output);
3213+
Input->eraseFromParent();
3214+
return Output;
3215+
}
3216+
3217+
llvm::IntrinsicInst *
3218+
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
3219+
llvm::Value *ParentToken) {
3220+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3221+
if (BB->empty())
3222+
Builder.SetInsertPoint(BB);
3223+
else
3224+
Builder.SetInsertPoint(BB->getFirstInsertionPt());
3225+
3226+
llvm::CallBase *CB = Builder.CreateIntrinsic(
3227+
llvm::Intrinsic::experimental_convergence_loop, {}, {});
3228+
Builder.restoreIP(IP);
3229+
3230+
llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
3231+
return cast<llvm::IntrinsicInst>(I);
3232+
}
3233+
3234+
llvm::IntrinsicInst *
3235+
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
3236+
llvm::BasicBlock *BB = &F->getEntryBlock();
3237+
llvm::IntrinsicInst *Token = getConvergenceToken(BB);
3238+
if (Token)
3239+
return Token;
3240+
3241+
// Adding a convergence token requires the function to be marked as
3242+
// convergent.
3243+
F->setConvergent();
3244+
3245+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3246+
Builder.SetInsertPoint(&BB->front());
3247+
llvm::CallBase *I = Builder.CreateIntrinsic(
3248+
llvm::Intrinsic::experimental_convergence_entry, {}, {});
3249+
assert(isa<llvm::IntrinsicInst>(I));
3250+
Builder.restoreIP(IP);
3251+
3252+
return cast<llvm::IntrinsicInst>(I);
3253+
}

clang/lib/CodeGen/CodeGenFunction.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
353353
assert(DeferredDeactivationCleanupStack.empty() &&
354354
"mismatched activate/deactivate of cleanups!");
355355

356+
if (CGM.shouldEmitConvergenceTokens()) {
357+
ConvergenceTokenStack.pop_back();
358+
assert(ConvergenceTokenStack.empty() &&
359+
"mismatched push/pop in convergence stack!");
360+
}
361+
356362
bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
357363
&& NumSimpleReturnExprs == NumReturnExprs
358364
&& ReturnBlock.getBlock()->use_empty();
@@ -1277,6 +1283,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
12771283
if (CurFuncDecl)
12781284
if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
12791285
LargestVectorWidth = VecWidth->getVectorWidth();
1286+
1287+
if (CGM.shouldEmitConvergenceTokens())
1288+
ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
12801289
}
12811290

12821291
void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {

clang/lib/CodeGen/CodeGenFunction.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ class CodeGenFunction : public CodeGenTypeCache {
315315
/// Stack to track the Logical Operator recursion nest for MC/DC.
316316
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
317317

318+
/// Stack to track the controlled convergence tokens.
319+
SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
320+
318321
/// Number of nested loop to be consumed by the last surrounding
319322
/// loop-associated directive.
320323
int ExpectedOMPLoopDepth = 0;
@@ -5076,7 +5079,11 @@ class CodeGenFunction : public CodeGenTypeCache {
50765079
const llvm::Twine &Name = "");
50775080
// Adds a convergencectrl operand bundle to |Input|, using the token on top of
50785081
// ConvergenceTokenStack as the parent token.
5079-
llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
5082+
template <typename CallType>
5083+
CallType *addControlledConvergenceToken(CallType *Input) {
5084+
return cast<CallType>(
5085+
addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
5086+
}
50805087

50815088
private:
50825089
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|

clang/lib/CodeGen/CodeGenModule.h

+8
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,14 @@ class CodeGenModule : public CodeGenTypeCache {
15861586
void AddGlobalDtor(llvm::Function *Dtor, int Priority = 65535,
15871587
bool IsDtorAttrFunc = false);
15881588

1589+
// Return whether structured convergence intrinsics should be generated for
1590+
// this target.
1591+
bool shouldEmitConvergenceTokens() const {
1592+
// TODO: this should probably become unconditional once the controlled
1593+
// convergence becomes the norm.
1594+
return getTriple().isSPIRVLogical();
1595+
}
1596+
15891597
private:
15901598
llvm::Constant *GetOrCreateLLVMFunction(
15911599
StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,

clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
21
// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
32

43
RWBuffer<float> Buf;

0 commit comments

Comments
 (0)