Skip to content

Commit f404047

Browse files
authored
[DirectX][OpLowering] Simplify named struct handling (#128247)
This removes "replaceFunctionWithNamedStructOp" and folds its functionality into "replaceFunctionWithOp". It turns out we were overcomplicating things and this is trivial to handle generically. Fixes #113192
1 parent 75bb25b commit f404047

File tree

4 files changed

+30
-61
lines changed

4 files changed

+30
-61
lines changed

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ def MakeDouble : DXILOp<101, makeDouble> {
930930

931931
def SplitDouble : DXILOp<102, splitDouble> {
932932
let Doc = "Splits a double into 2 uints";
933+
let intrinsics = [IntrinSelect<int_dx_splitdouble>];
933934
let arguments = [OverloadTy];
934935
let result = SplitDoubleTy;
935936
let overloads = [Overloads<DXIL1_0, [DoubleTy]>];

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,10 +535,6 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
535535
return ::getResRetType(ElementTy);
536536
}
537537

538-
StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
539-
return ::getSplitDoubleType(Context);
540-
}
541-
542538
StructType *DXILOpBuilder::getHandleType() {
543539
return ::getHandleType(IRB.getContext());
544540
}

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ class DXILOpBuilder {
5050
/// Get a `%dx.types.ResRet` type with the given element type.
5151
StructType *getResRetType(Type *ElementTy);
5252

53-
/// Get the `%dx.types.splitdouble` type.
54-
StructType *getSplitDoubleType(LLVMContext &Context);
55-
5653
/// Get the `%dx.types.Handle` type.
5754
StructType *getHandleType();
5855

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ class OpLowerer {
120120
int Value;
121121
};
122122

123+
/// Replaces uses of a struct with uses of an equivalent named struct.
124+
///
125+
/// DXIL operations that return structs give them well known names, so we need
126+
/// to update uses when we switch from an LLVM intrinsic to an op.
127+
Error replaceNamedStructUses(CallInst *Intrin, CallInst *DXILOp) {
128+
auto *IntrinTy = cast<StructType>(Intrin->getType());
129+
auto *DXILOpTy = cast<StructType>(DXILOp->getType());
130+
if (!IntrinTy->isLayoutIdentical(DXILOpTy))
131+
return make_error<StringError>(
132+
"Type mismatch between intrinsic and DXIL op",
133+
inconvertibleErrorCode());
134+
135+
for (Use &U : make_early_inc_range(Intrin->uses()))
136+
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()))
137+
EVI->setOperand(0, DXILOp);
138+
else if (auto *IVI = dyn_cast<InsertValueInst>(U.getUser()))
139+
IVI->setOperand(0, DXILOp);
140+
else
141+
return make_error<StringError>("DXIL ops that return structs may only "
142+
"be used by insert- and extractvalue",
143+
inconvertibleErrorCode());
144+
return Error::success();
145+
}
146+
123147
[[nodiscard]] bool
124148
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
125149
ArrayRef<IntrinArgSelect> ArgSelects) {
@@ -154,32 +178,13 @@ class OpLowerer {
154178
if (Error E = OpCall.takeError())
155179
return E;
156180

157-
CI->replaceAllUsesWith(*OpCall);
158-
CI->eraseFromParent();
159-
return Error::success();
160-
});
161-
}
162-
163-
[[nodiscard]] bool replaceFunctionWithNamedStructOp(
164-
Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
165-
llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
166-
bool IsVectorArgExpansion = isVectorArgExpansion(F);
167-
return replaceFunction(F, [&](CallInst *CI) -> Error {
168-
SmallVector<Value *> Args;
169-
OpBuilder.getIRB().SetInsertPoint(CI);
170-
if (IsVectorArgExpansion) {
171-
SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
172-
Args.append(NewArgs.begin(), NewArgs.end());
181+
if (isa<StructType>(CI->getType())) {
182+
if (Error E = replaceNamedStructUses(CI, *OpCall))
183+
return E;
173184
} else
174-
Args.append(CI->arg_begin(), CI->arg_end());
175-
176-
Expected<CallInst *> OpCall =
177-
OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
178-
if (Error E = OpCall.takeError())
179-
return E;
180-
if (Error E = ReplaceUses(CI, *OpCall))
181-
return E;
185+
CI->replaceAllUsesWith(*OpCall);
182186

187+
CI->eraseFromParent();
183188
return Error::success();
184189
});
185190
}
@@ -359,26 +364,6 @@ class OpLowerer {
359364
return lowerToBindAndAnnotateHandle(F);
360365
}
361366

362-
Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
363-
for (Use &U : make_early_inc_range(Intrin->uses())) {
364-
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
365-
366-
if (EVI->getNumIndices() != 1)
367-
return createStringError(std::errc::invalid_argument,
368-
"Splitdouble has only 2 elements");
369-
EVI->setOperand(0, Op);
370-
} else {
371-
return make_error<StringError>(
372-
"Splitdouble use is not ExtractValueInst",
373-
inconvertibleErrorCode());
374-
}
375-
}
376-
377-
Intrin->eraseFromParent();
378-
379-
return Error::success();
380-
}
381-
382367
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
383368
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
384369
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
@@ -814,16 +799,6 @@ class OpLowerer {
814799
case Intrinsic::dx_resource_updatecounter:
815800
HasErrors |= lowerUpdateCounter(F);
816801
break;
817-
// TODO: this can be removed when
818-
// https://github.com/llvm/llvm-project/issues/113192 is fixed
819-
case Intrinsic::dx_splitdouble:
820-
HasErrors |= replaceFunctionWithNamedStructOp(
821-
F, OpCode::SplitDouble,
822-
OpBuilder.getSplitDoubleType(M.getContext()),
823-
[&](CallInst *CI, CallInst *Op) {
824-
return replaceSplitDoubleCallUsages(CI, Op);
825-
});
826-
break;
827802
case Intrinsic::ctpop:
828803
HasErrors |= lowerCtpopToCountBits(F);
829804
break;

0 commit comments

Comments
 (0)