-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[Scalarizer][DirectX] support structs return types #111569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-backend-directx Author: Farzon Lotfi (farzonl) ChangesBased on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306 LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations. Three changes are needed.
Testing changes
Full diff: https://github.com/llvm/llvm-project/pull/111569.diff 5 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f2b9e286ebb476..5f0f856df8e2b0 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,5 +86,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
+ [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..4ddf39a4337df6 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -28,6 +28,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
switch (ID) {
case Intrinsic::dx_frac:
case Intrinsic::dx_rsqrt:
+ case Intrinsic::dx_splitdouble:
return true;
default:
return false;
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 72728c0f839e5d..d8b052061c1ad5 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -197,6 +197,23 @@ struct VectorLayout {
uint64_t SplitSize = 0;
};
+static bool isStructAllVectors(Type *Ty) {
+ if (!isa<StructType>(Ty))
+ return false;
+ if (Ty->getNumContainedTypes() < 1)
+ return false;
+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
+ if (!VecTy)
+ return false;
+ unsigned VecSize = VecTy->getNumElements();
+ for (unsigned I = 1; I < Ty->getNumContainedTypes(); I++) {
+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
+ if (!VecTy || VecSize != VecTy->getNumElements())
+ return false;
+ }
+ return true;
+}
+
/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +293,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
bool visitBitCastInst(BitCastInst &BCI);
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
+ bool visitExtractValueInst(ExtractValueInst &EVI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
bool visitPHINode(PHINode &PHI);
bool visitLoadInst(LoadInst &LI);
@@ -667,6 +685,11 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
if (isTriviallyVectorizable(ID))
return true;
+ // TODO: investigate vectorizable frexp
+ switch (ID) {
+ case Intrinsic::frexp:
+ return true;
+ }
return Intrinsic::isTargetIntrinsic(ID) &&
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}
@@ -674,7 +697,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
- std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
+ Type *CallType = CI.getType();
+ bool AreAllVectors = isStructAllVectors(CallType);
+ std::optional<VectorSplit> VS;
+ if (AreAllVectors)
+ VS = getVectorSplit(CallType->getContainedType(0));
+ else
+ VS = getVectorSplit(CallType);
if (!VS)
return false;
@@ -699,6 +728,18 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
Tys.push_back(VS->SplitTy);
+ if (AreAllVectors) {
+ Type *PrevType = CallType->getContainedType(0);
+ Type *CallType = CI.getType();
+ for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
+ Type *CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
+ if (PrevType != CurrType) {
+ std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
+ Tys.push_back(CurrVS->SplitTy);
+ PrevType = CurrType;
+ }
+ }
+ }
// Assumes that any vector type has the same number of elements as the return
// vector type, which is true for all current intrinsics.
for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1029,6 +1070,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
return true;
}
+bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
+ Value *Op = EVI.getOperand(0);
+ Type *OpTy = Op->getType();
+ ValueVector Res;
+ if (!isStructAllVectors(OpTy))
+ return false;
+ Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
+ std::optional<VectorSplit> VS = getVectorSplit(VecType);
+ if (!VS)
+ return false;
+ IRBuilder<> Builder(&EVI);
+ Scatterer Op0 = scatter(&EVI, Op, *VS);
+ assert(!EVI.getIndices().empty() && "Make sure an index exists");
+ // Note for our use case we only care about the top level index.
+ unsigned Index = EVI.getIndices()[0];
+ for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
+ Value *ResElem = Builder.CreateExtractValue(
+ Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
+ Res.push_back(ResElem);
+ }
+
+ gather(&EVI, Res, *VS);
+ return true;
+}
+
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
if (!VS)
@@ -1195,7 +1261,7 @@ bool ScalarizerVisitor::finish() {
if (!Op->use_empty()) {
// The value is still needed, so recreate it using a series of
// insertelements and/or shufflevectors.
- Value *Res;
+ Value *Res = nullptr;
if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
BasicBlock *BB = Op->getParent();
IRBuilder<> Builder(Op);
@@ -1208,6 +1274,35 @@ bool ScalarizerVisitor::finish() {
Res = concatenate(Builder, CV, VS, Op->getName());
Res->takeName(Op);
+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
+ BasicBlock *BB = Op->getParent();
+ IRBuilder<> Builder(Op);
+ if (isa<PHINode>(Op))
+ Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+ // Iterate over each element in the struct
+ unsigned NumOfStructElements = Ty->getNumElements();
+ SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
+ for (unsigned I = 0; I < NumOfStructElements; ++I) {
+ for (auto *CVelem : CV) {
+ Value *Elem = Builder.CreateExtractValue(
+ CVelem, I, Op->getName() + ".elem" + std::to_string(I));
+ ElemCV[I].push_back(Elem);
+ }
+ }
+ Res = PoisonValue::get(Ty);
+ for (unsigned I = 0; I < NumOfStructElements; ++I) {
+ Type *ElemTy = Ty->getElementType(I);
+ assert(isa<FixedVectorType>(ElemTy) &&
+ "Only Structs of all FixedVectorType supported");
+ VectorSplit VS = *getVectorSplit(ElemTy);
+ assert(VS.NumFragments == CV.size());
+
+ Value *ConcatenatedVector =
+ concatenate(Builder, ElemCV[I], VS, Op->getName());
+ Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
+ Op->getName() + ".insert");
+ }
} else {
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
Res = CV[0];
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
new file mode 100644
index 00000000000000..9b70e87ba4794e
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -0,0 +1,40 @@
+; RUN: opt -passes='function(scalarizer<load-store>)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; CHECK-LABEL: @test_vector_double_split_void
+define void @test_vector_double_split_void(<2 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+ ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+ ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+ ; CHECK-NOT: extractvalue { i32, i32 } {{.*}}, 0
+ ; CHECK-NOT: insertelement <2 x i32> {{.*}}, i32 {{.*}}, i64 0
+ %hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d)
+ ret void
+}
+
+; CHECK-LABEL: @test_vector_double_split
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
+ ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <3 x double> %d, i64 1
+ ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+ ; CHECK: [[ee2:%.*]] = extractelement <3 x double> %d, i64 2
+ ; CHECK: [[ie2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee2]])
+ ; CHECK: [[ev00:%.*]] = extractvalue { i32, i32 } [[ie0]], 0
+ ; CHECK: [[ev01:%.*]] = extractvalue { i32, i32 } [[ie1]], 0
+ ; CHECK: [[ev02:%.*]] = extractvalue { i32, i32 } [[ie2]], 0
+ ; CHECK: [[ev10:%.*]] = extractvalue { i32, i32 } [[ie0]], 1
+ ; CHECK: [[ev11:%.*]] = extractvalue { i32, i32 } [[ie1]], 1
+ ; CHECK: [[ev12:%.*]] = extractvalue { i32, i32 } [[ie2]], 1
+ ; CHECK: [[add1:%.*]] = add i32 [[ev00]], [[ev10]]
+ ; CHECK: [[add2:%.*]] = add i32 [[ev01]], [[ev11]]
+ ; CHECK: [[add3:%.*]] = add i32 [[ev02]], [[ev12]]
+ ; CHECK: insertelement <3 x i32> poison, i32 [[add1]], i64 0
+ ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add2]], i64 1
+ ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add3]], i64 2
+ %hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+ %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+ %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+ %3 = add <3 x i32> %1, %2
+ ret <3 x i32> %3
+}
diff --git a/llvm/test/Transforms/Scalarizer/frexp.ll b/llvm/test/Transforms/Scalarizer/frexp.ll
new file mode 100644
index 00000000000000..48159b45c18960
--- /dev/null
+++ b/llvm/test/Transforms/Scalarizer/frexp.ll
@@ -0,0 +1,67 @@
+; RUN: opt %s -passes='function(scalarizer<load-store>)' -S | FileCheck %s
+
+; CHECK-LABEL: @test_vector_half_frexp_half
+define noundef <2 x half> @test_vector_half_frexp_half(<2 x half> noundef %h) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { half, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { half, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x half> poison, half [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x half> %{{.*}}, half [[ev01]], i64 1
+ %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+ %e0 = extractvalue { <2 x half>, <2 x i32> } %r, 0
+ ret <2 x half> %e0
+}
+
+; CHECK-LABEL: @test_vector_half_frexp_int
+define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+ ; CHECK-NEXT: [[ev10:%.*]] = extractvalue { half, i32 } [[ie0]], 1
+ ; CHECK-NEXT: [[ev11:%.*]] = extractvalue { half, i32 } [[ie1]], 1
+ ; CHECK-NEXT: insertelement <2 x i32> poison, i32 [[ev10]], i64 0
+ ; CHECK-NEXT: insertelement <2 x i32> %{{.*}}, i32 [[ev11]], i64 1
+ %r = call { <2 x half>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x half> %h)
+ %e1 = extractvalue { <2 x half>, <2 x i32> } %r, 1
+ ret <2 x i32> %e1
+}
+
+; CHECK-LABEL: @test_vector_float_frexp_int
+define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x float> %f, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { float, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { float, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x float> poison, float [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x float> %{{.*}}, float [[ev01]], i64 1
+ ; CHECK-NEXT: extractvalue { float, i32 } [[ie0]], 1
+ ; CHECK-NEXT: extractvalue { float, i32 } [[ie1]], 1
+ %1 = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x float> %f)
+ %2 = extractvalue { <2 x float>, <2 x i32> } %1, 0
+ %3 = extractvalue { <2 x float>, <2 x i32> } %1, 1
+ ret <2 x float> %2
+}
+
+; CHECK-LABEL: @test_vector_double_frexp_int
+define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
+ ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+ ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])
+ ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+ ; CHECK-NEXT: [[ie1:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee1]])
+ ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { double, i32 } [[ie0]], 0
+ ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { double, i32 } [[ie1]], 0
+ ; CHECK-NEXT: insertelement <2 x double> poison, double [[ev00]], i64 0
+ ; CHECK-NEXT: insertelement <2 x double> %{{.*}}, double [[ev01]], i64 1
+ ; CHECK-NEXT: extractvalue { double, i32 } [[ie0]], 1
+ ; CHECK-NEXT: extractvalue { double, i32 } [[ie1]], 1
+ %1 = call { <2 x double>, <2 x i32> } @llvm.frexp.v2f64.v2i32(<2 x double> %d)
+ %2 = extractvalue { <2 x double>, <2 x i32> } %1, 0
+ %3 = extractvalue { <2 x double>, <2 x i32> } %1, 1
+ ret <2 x double> %2
+}
|
@nikic do you have some time do to a review? |
9f94e81
to
aeba58e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't review this in detail, but looks sensible.
aeba58e
to
4b9b203
Compare
4b9b203
to
7643c8a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I approved after thinking you had accepted my suggested change, so I'd say that approval should be contingent on resolving the last comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update!
829583d
to
e00e740
Compare
Hi @farzonl I'm having problems when compiling:
Are you planning to add |
@mariusz-sikora-at-amd I wasn't aware of I'm a little suprised you are getting a casting error though. The only non dyn_casts are on |
At the beginning I thought that this is AMDGPU specific issue, but I also build AArch64 and used existing test and it failed also. This test is running
|
This is rather assert while we are doing |
Thats a bug. Adding to |
@mariusz-sikora-at-amd this bug will be fixed by: #113625. |
Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306
LLVM intrinsics do not support out params. To get around this limitation implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations.
Three changes are needed.
ExtractValue
instructionsTesting changes
llvm.frexp
llvm.dx.splitdouble
fixes #111437