Skip to content

Commit 45ad88d

Browse files
farzonlKornevNikita
authored andcommitted
[Scalarizer][DirectX] support structs return types (#111569)
Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306 LLVM intrinsics do not support out params. To get around this limitation, implementers will make intrinsics return structs to capture a return type and an out param. This implementation detail should not impact scalarization since these cases should be elementwise operations. ## Three changes are needed. - The CallInst visitor needs to be updated to handle Structs - A new visitor is needed for `ExtractValue` instructions - `finish` needs to be updated to handle structs so that insert elements are properly propagated. ## Testing changes - Add support for `llvm.frexp` - Add support for `llvm.dx.splitdouble` fixes llvm/llvm-project#111437
1 parent 99e7986 commit 45ad88d

File tree

7 files changed

+297
-39
lines changed

7 files changed

+297
-39
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
154154
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
155155
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
156156

157+
/// Identifies if the vector form of the intrinsic that returns a struct is
158+
/// overloaded at the struct element index \p RetIdx.
159+
bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
160+
int RetIdx);
161+
157162
/// Returns intrinsic ID for call.
158163
/// For the input call instruction it finds mapping intrinsic and returns
159164
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
9494
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
9595
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
9696
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
97+
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
98+
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
9799
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
98100
}

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
152152
}
153153
}
154154

155+
/// Reports whether the struct-returning vector form of intrinsic \p ID is
/// overloaded at struct field index \p RetIdx.
///
/// frexp is reported as overloaded at fields 0 and 1; every other intrinsic ID
/// is reported as overloaded at field 0 only.
bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
156+
int RetIdx) {
157+
switch (ID) {
158+
case Intrinsic::frexp:
159+
return RetIdx == 0 || RetIdx == 1;
160+
default:
// Fallback for all intrinsics not listed above: only field 0 counts.
161+
return RetIdx == 0;
162+
}
163+
}
164+
155165
/// Returns intrinsic ID for call.
156166
/// For the input call instruction it finds mapping intrinsic and returns
157167
/// its ID; if no mapping is found, it returns not_intrinsic.
Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,39 @@
1-
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2-
//-*-===//
3-
//
4-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5-
// See https://llvm.org/LICENSE.txt for license information.
6-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7-
//
8-
//===----------------------------------------------------------------------===//
9-
///
10-
//===----------------------------------------------------------------------===//
11-
12-
#include "DirectXTargetTransformInfo.h"
13-
#include "llvm/IR/Intrinsics.h"
14-
#include "llvm/IR/IntrinsicsDirectX.h"
15-
16-
using namespace llvm;
17-
18-
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19-
unsigned ScalarOpdIdx) {
20-
switch (ID) {
21-
case Intrinsic::dx_wave_readlane:
22-
return ScalarOpdIdx == 1;
23-
default:
24-
return false;
25-
}
26-
}
27-
28-
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
29-
Intrinsic::ID ID) const {
30-
switch (ID) {
31-
case Intrinsic::dx_frac:
32-
case Intrinsic::dx_rsqrt:
33-
case Intrinsic::dx_wave_readlane:
34-
return true;
35-
default:
36-
return false;
37-
}
38-
}
1+
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "DirectXTargetTransformInfo.h"
13+
#include "llvm/IR/Intrinsics.h"
14+
#include "llvm/IR/IntrinsicsDirectX.h"
15+
16+
using namespace llvm;
17+
18+
/// Returns true only for dx_wave_readlane when \p ScalarOpdIdx is 1; returns
/// false for every other intrinsic ID / operand index combination.
/// NOTE(review): judging by the name, a true result presumably marks an
/// operand that stays scalar when the call is scalarized -- confirm against
/// the TTI interface documentation.
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19+
unsigned ScalarOpdIdx) {
20+
switch (ID) {
21+
case Intrinsic::dx_wave_readlane:
22+
return ScalarOpdIdx == 1;
23+
default:
24+
return false;
25+
}
26+
}
27+
28+
/// Returns true for the DirectX intrinsics listed in the switch below
/// (dx_frac, dx_rsqrt, dx_wave_readlane, dx_splitdouble) and false for any
/// other intrinsic ID.
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
29+
Intrinsic::ID ID) const {
30+
switch (ID) {
31+
case Intrinsic::dx_frac:
32+
case Intrinsic::dx_rsqrt:
33+
case Intrinsic::dx_wave_readlane:
34+
case Intrinsic::dx_splitdouble:
35+
return true;
36+
default:
37+
return false;
38+
}
39+
}

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,24 @@ struct VectorLayout {
197197
uint64_t SplitSize = 0;
198198
};
199199

200+
/// Returns true when \p Ty is a struct type with at least one field in which
/// every field is a FixedVectorType and all fields have the same number of
/// elements. Only element counts are compared; element types are not checked
/// against each other. Returns false in every other case.
static bool isStructOfMatchingFixedVectors(Type *Ty) {
201+
if (!isa<StructType>(Ty))
202+
return false;
203+
unsigned StructSize = Ty->getNumContainedTypes();
// An empty struct has no field to take the reference element count from.
204+
if (StructSize < 1)
205+
return false;
// Field 0 provides the element count that every later field must match.
206+
FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
207+
if (!VecTy)
208+
return false;
209+
unsigned VecSize = VecTy->getNumElements();
// Start at 1: field 0 was already validated above.
210+
for (unsigned I = 1; I < StructSize; I++) {
211+
VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
212+
if (!VecTy || VecSize != VecTy->getNumElements())
213+
return false;
214+
}
215+
return true;
216+
}
217+
200218
/// Concatenate the given fragments to a single vector value of the type
201219
/// described in @p VS.
202220
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +294,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276294
bool visitBitCastInst(BitCastInst &BCI);
277295
bool visitInsertElementInst(InsertElementInst &IEI);
278296
bool visitExtractElementInst(ExtractElementInst &EEI);
297+
bool visitExtractValueInst(ExtractValueInst &EVI);
279298
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
280299
bool visitPHINode(PHINode &PHI);
281300
bool visitLoadInst(LoadInst &LI);
@@ -667,14 +686,26 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
667686
/// Returns true when intrinsic \p ID is considered trivially scalarizable:
/// either isTriviallyVectorizable(ID) holds, or ID is frexp (special-cased
/// here, see the TODO below), or ID is a target intrinsic that the target's
/// TTI reports as trivially scalarizable.
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
668687
if (isTriviallyVectorizable(ID))
669688
return true;
689+
// TODO: Move frexp to isTriviallyVectorizable.
690+
// https://github.com/llvm/llvm-project/issues/112408
691+
switch (ID) {
692+
case Intrinsic::frexp:
693+
return true;
694+
}
// Anything else is accepted only if it is a target intrinsic and the target
// opts in through TTI.
670695
return Intrinsic::isTargetIntrinsic(ID) &&
671696
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
672697
}
673698

674699
/// If a call to a vector typed intrinsic function, split into a scalar call per
675700
/// element if possible for the intrinsic.
676701
bool ScalarizerVisitor::splitCall(CallInst &CI) {
677-
std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
702+
Type *CallType = CI.getType();
703+
bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType);
704+
std::optional<VectorSplit> VS;
705+
if (AreAllVectorsOfMatchingSize)
706+
VS = getVectorSplit(CallType->getContainedType(0));
707+
else
708+
VS = getVectorSplit(CallType);
678709
if (!VS)
679710
return false;
680711

@@ -699,6 +730,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
699730
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
700731
Tys.push_back(VS->SplitTy);
701732

733+
if (AreAllVectorsOfMatchingSize) {
734+
for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
735+
std::optional<VectorSplit> CurrVS =
736+
getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
737+
// This case does not seem to happen, but it is possible for
738+
// VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit
739+
// is not returned and we will bailout of handling this call.
740+
// The secondary bailout case is if NumPacked does not match.
741+
// This can happen if ScalarizeMinBits is not set to the default.
742+
// This means with certain ScalarizeMinBits intrinsics like frexp
743+
// will only scalarize when the struct elements have the same bitness.
744+
if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
745+
return false;
746+
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
747+
Tys.push_back(CurrVS->SplitTy);
748+
}
749+
}
702750
// Assumes that any vector type has the same number of elements as the return
703751
// vector type, which is true for all current intrinsics.
704752
for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1030,6 +1078,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
10301078
return true;
10311079
}
10321080

1081+
/// Visitor for extractvalue instructions whose aggregate operand is a struct
/// of equally sized fixed vectors (see isStructOfMatchingFixedVectors).
///
/// The operand is scattered using the VectorSplit of struct field 0; for each
/// resulting fragment an extractvalue of the requested field is emitted before
/// \p EVI, and the collected values are passed to gather() for \p EVI.
///
/// Returns false, without emitting any instruction, when the operand type is
/// not such a struct or when no VectorSplit is available for field 0; returns
/// true otherwise.
bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
1082+
Value *Op = EVI.getOperand(0);
1083+
Type *OpTy = Op->getType();
1084+
ValueVector Res;
1085+
if (!isStructOfMatchingFixedVectors(OpTy))
1086+
return false;
// NOTE(review): the split is always derived from field 0 and is also the one
// handed to gather() below, even when the extracted field is not field 0.
// Presumably fine while every field splits the same way -- confirm for
// structs whose fields have different element types.
1087+
Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
1088+
std::optional<VectorSplit> VS = getVectorSplit(VecType);
1089+
if (!VS)
1090+
return false;
1091+
IRBuilder<> Builder(&EVI);
1092+
Scatterer Op0 = scatter(&EVI, Op, *VS);
1093+
assert(!EVI.getIndices().empty() && "Make sure an index exists");
1094+
// Note for our use case we only care about the top level index.
1095+
unsigned Index = EVI.getIndices()[0];
// Pull field `Index` out of every scattered fragment of the operand.
1096+
for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
1097+
Value *ResElem = Builder.CreateExtractValue(
1098+
Op0[OpIdx], Index, EVI.getName() + ".elem" + Twine(Index));
1099+
Res.push_back(ResElem);
1100+
}
1101+
1102+
gather(&EVI, Res, *VS);
1103+
return true;
1104+
}
1105+
10331106
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
10341107
std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
10351108
if (!VS)
@@ -1209,6 +1282,35 @@ bool ScalarizerVisitor::finish() {
12091282
Res = concatenate(Builder, CV, VS, Op->getName());
12101283

12111284
Res->takeName(Op);
1285+
} else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
1286+
BasicBlock *BB = Op->getParent();
1287+
IRBuilder<> Builder(Op);
1288+
if (isa<PHINode>(Op))
1289+
Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
1290+
1291+
// Iterate over each element in the struct
1292+
unsigned NumOfStructElements = Ty->getNumElements();
1293+
SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
1294+
for (unsigned I = 0; I < NumOfStructElements; ++I) {
1295+
for (auto *CVelem : CV) {
1296+
Value *Elem = Builder.CreateExtractValue(
1297+
CVelem, I, Op->getName() + ".elem" + Twine(I));
1298+
ElemCV[I].push_back(Elem);
1299+
}
1300+
}
1301+
Res = PoisonValue::get(Ty);
1302+
for (unsigned I = 0; I < NumOfStructElements; ++I) {
1303+
Type *ElemTy = Ty->getElementType(I);
1304+
assert(isa<FixedVectorType>(ElemTy) &&
1305+
"Only Structs of all FixedVectorType supported");
1306+
VectorSplit VS = *getVectorSplit(ElemTy);
1307+
assert(VS.NumFragments == CV.size());
1308+
1309+
Value *ConcatenatedVector =
1310+
concatenate(Builder, ElemCV[I], VS, Op->getName());
1311+
Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
1312+
Op->getName() + ".insert");
1313+
}
12121314
} else {
12131315
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
12141316
Res = CV[0];
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
define void @test_vector_double_split_void(<2 x double> noundef %d) {
5+
; CHECK-LABEL: define void @test_vector_double_split_void(
6+
; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
7+
; CHECK-NEXT: [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
8+
; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
9+
; CHECK-NEXT: [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
10+
; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
11+
; CHECK-NEXT: ret void
12+
;
13+
%hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d)
14+
ret void
15+
}
16+
17+
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
18+
; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
19+
; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
20+
; CHECK-NEXT: [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
21+
; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
22+
; CHECK-NEXT: [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
23+
; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
24+
; CHECK-NEXT: [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
25+
; CHECK-NEXT: [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
26+
; CHECK-NEXT: [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
27+
; CHECK-NEXT: [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
28+
; CHECK-NEXT: [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
29+
; CHECK-NEXT: [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
30+
; CHECK-NEXT: [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
31+
; CHECK-NEXT: [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
32+
; CHECK-NEXT: [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
33+
; CHECK-NEXT: [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
34+
; CHECK-NEXT: [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
35+
; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
36+
; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
37+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
38+
; CHECK-NEXT: ret <3 x i32> [[TMP1]]
39+
;
40+
%hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
41+
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
42+
%2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
43+
%3 = add <3 x i32> %1, %2
44+
ret <3 x i32> %3
45+
}

0 commit comments

Comments
 (0)