Skip to content

[Scalarizer][DirectX] support structs return types #111569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
10 changes: 10 additions & 0 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
}
}

/// Reports whether struct-return field \p RetIdx of intrinsic \p ID is treated
/// as an overloaded type. Field 0 is reported as overloaded for every
/// intrinsic; frexp additionally reports field 1.
bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
                                                            int RetIdx) {
  // The first field is accepted regardless of which intrinsic is queried.
  if (RetIdx == 0)
    return true;
  // Only frexp has a second field that counts as overloaded.
  return ID == Intrinsic::frexp && RetIdx == 1;
}

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its ID; if no mapping is found, it returns not_intrinsic.
Expand Down
77 changes: 39 additions & 38 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
//===----------------------------------------------------------------------===//

#include "DirectXTargetTransformInfo.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"

using namespace llvm;

/// Reports whether operand \p ScalarOpdIdx of the target intrinsic \p ID is
/// flagged as a scalar operand. Only dx_wave_readlane is handled, and only its
/// operand at index 1; every other combination yields false.
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                        unsigned ScalarOpdIdx) {
  return ID == Intrinsic::dx_wave_readlane && ScalarOpdIdx == 1;
}

/// Reports whether the intrinsic \p ID is on this target's list of trivially
/// scalarizable intrinsics: true for dx_frac, dx_rsqrt and dx_wave_readlane,
/// false for every other ID.
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
Intrinsic::ID ID) const {
switch (ID) {
case Intrinsic::dx_frac:
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_wave_readlane:
return true;
default:
// Anything not listed above is reported as not trivially scalarizable.
return false;
}
}
//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
//===----------------------------------------------------------------------===//

#include "DirectXTargetTransformInfo.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"

using namespace llvm;

/// Reports whether operand \p ScalarOpdIdx of the target intrinsic \p ID is
/// flagged as a scalar operand. Only dx_wave_readlane is handled, and only its
/// operand at index 1; every other combination yields false.
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                        unsigned ScalarOpdIdx) {
  return ID == Intrinsic::dx_wave_readlane && ScalarOpdIdx == 1;
}

/// Reports whether the intrinsic \p ID is on this target's list of trivially
/// scalarizable intrinsics: true for dx_frac, dx_rsqrt, dx_wave_readlane and
/// dx_splitdouble, false for every other ID.
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
Intrinsic::ID ID) const {
switch (ID) {
case Intrinsic::dx_frac:
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_splitdouble:
return true;
default:
// Anything not listed above is reported as not trivially scalarizable.
return false;
}
}
104 changes: 103 additions & 1 deletion llvm/lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,24 @@ struct VectorLayout {
uint64_t SplitSize = 0;
};

/// Returns true when \p Ty is a struct type with at least one member whose
/// members are all fixed vectors with the same number of elements.
static bool isStructOfMatchingFixedVectors(Type *Ty) {
  auto *StructTy = dyn_cast<StructType>(Ty);
  if (!StructTy || StructTy->getNumElements() == 0)
    return false;

  // The first member fixes the lane count every later member must match.
  auto *FirstVecTy = dyn_cast<FixedVectorType>(StructTy->getElementType(0));
  if (!FirstVecTy)
    return false;
  const unsigned ExpectedLanes = FirstVecTy->getNumElements();

  for (unsigned Idx = 1, E = StructTy->getNumElements(); Idx != E; ++Idx) {
    auto *MemberVecTy =
        dyn_cast<FixedVectorType>(StructTy->getElementType(Idx));
    if (!MemberVecTy || MemberVecTy->getNumElements() != ExpectedLanes)
      return false;
  }
  return true;
}

/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
Expand Down Expand Up @@ -276,6 +294,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
bool visitBitCastInst(BitCastInst &BCI);
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitExtractValueInst(ExtractValueInst &EVI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
bool visitPHINode(PHINode &PHI);
bool visitLoadInst(LoadInst &LI);
Expand Down Expand Up @@ -667,14 +686,26 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
/// Decides whether calls to intrinsic \p ID may be split into per-element
/// calls: true for trivially vectorizable intrinsics and frexp, otherwise only
/// for target intrinsics that the TTI reports as trivially scalarizable.
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
  // TODO: Move frexp to isTriviallyVectorizable.
  // https://github.com/llvm/llvm-project/issues/112408
  const bool IsGenericallyScalarizable =
      isTriviallyVectorizable(ID) || ID == Intrinsic::frexp;
  if (IsGenericallyScalarizable)
    return true;

  // Only target intrinsics are forwarded to the target hook.
  if (!Intrinsic::isTargetIntrinsic(ID))
    return false;
  return TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}

/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
Type *CallType = CI.getType();
bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType);
std::optional<VectorSplit> VS;
if (AreAllVectorsOfMatchingSize)
VS = getVectorSplit(CallType->getContainedType(0));
else
VS = getVectorSplit(CallType);
if (!VS)
return false;

Expand All @@ -699,6 +730,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
Tys.push_back(VS->SplitTy);

if (AreAllVectorsOfMatchingSize) {
for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
std::optional<VectorSplit> CurrVS =
getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
// This case does not seem to happen, but it is possible for
// VectorSplit.NumPacked >= NumElems. If that happens, a VectorSplit
// is not returned and we will bail out of handling this call.
// The secondary bail-out case is if NumPacked does not match.
// This can happen if ScalarizeMinBits is not set to the default.
// This means that, with certain ScalarizeMinBits values, intrinsics like
// frexp will only scalarize when the struct elements have the same bitness.
if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
return false;
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
Tys.push_back(CurrVS->SplitTy);
}
}
// Assumes that any vector type has the same number of elements as the return
// vector type, which is true for all current intrinsics.
for (unsigned I = 0; I != NumArgs; ++I) {
Expand Down Expand Up @@ -1030,6 +1078,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
return true;
}

/// Handles an extractvalue whose aggregate operand is a struct of fixed
/// vectors with matching element counts. One extractvalue is emitted per
/// scattered fragment of the operand and the results are gathered for \p EVI.
/// Returns false (leaving \p EVI alone) when the operand type does not
/// qualify or no vector split is available.
bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
  Value *Aggregate = EVI.getOperand(0);
  Type *AggregateTy = Aggregate->getType();
  if (!isStructOfMatchingFixedVectors(AggregateTy))
    return false;

  // The split is computed from the struct's first vector member.
  std::optional<VectorSplit> VS =
      getVectorSplit(cast<FixedVectorType>(AggregateTy->getContainedType(0)));
  if (!VS)
    return false;

  IRBuilder<> Builder(&EVI);
  Scatterer Fragments = scatter(&EVI, Aggregate, *VS);
  assert(!EVI.getIndices().empty() && "Make sure an index exists");
  // Note for our use case we only care about the top level index.
  const unsigned FieldIdx = EVI.getIndices()[0];

  ValueVector Extracted;
  for (unsigned FragIdx = 0; FragIdx < Fragments.size(); ++FragIdx) {
    Value *Fragment = Fragments[FragIdx];
    Extracted.push_back(Builder.CreateExtractValue(
        Fragment, FieldIdx, EVI.getName() + ".elem" + Twine(FieldIdx)));
  }

  gather(&EVI, Extracted, *VS);
  return true;
}

bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
if (!VS)
Expand Down Expand Up @@ -1209,6 +1282,35 @@ bool ScalarizerVisitor::finish() {
Res = concatenate(Builder, CV, VS, Op->getName());

Res->takeName(Op);
} else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
BasicBlock *BB = Op->getParent();
IRBuilder<> Builder(Op);
if (isa<PHINode>(Op))
Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

// Iterate over each element in the struct
unsigned NumOfStructElements = Ty->getNumElements();
SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
for (unsigned I = 0; I < NumOfStructElements; ++I) {
for (auto *CVelem : CV) {
Value *Elem = Builder.CreateExtractValue(
CVelem, I, Op->getName() + ".elem" + Twine(I));
ElemCV[I].push_back(Elem);
}
}
Res = PoisonValue::get(Ty);
for (unsigned I = 0; I < NumOfStructElements; ++I) {
Type *ElemTy = Ty->getElementType(I);
assert(isa<FixedVectorType>(ElemTy) &&
"Only Structs of all FixedVectorType supported");
VectorSplit VS = *getVectorSplit(ElemTy);
assert(VS.NumFragments == CV.size());

Value *ConcatenatedVector =
concatenate(Builder, ElemCV[I], VS, Op->getName());
Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
Op->getName() + ".insert");
}
} else {
assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
Res = CV[0];
Expand Down
45 changes: 45 additions & 0 deletions llvm/test/CodeGen/DirectX/split-double.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

define void @test_vector_double_split_void(<2 x double> noundef %d) {
; CHECK-LABEL: define void @test_vector_double_split_void(
; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
; CHECK-NEXT: [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
; CHECK-NEXT: [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
; CHECK-NEXT: ret void
;
%hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d)
ret void
}

define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
; CHECK-NEXT: [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
; CHECK-NEXT: [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
; CHECK-NEXT: [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
; CHECK-NEXT: [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
; CHECK-NEXT: [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
; CHECK-NEXT: [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
; CHECK-NEXT: [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
; CHECK-NEXT: [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
; CHECK-NEXT: [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
; CHECK-NEXT: [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
; CHECK-NEXT: [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
; CHECK-NEXT: [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
; CHECK-NEXT: [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
; CHECK-NEXT: ret <3 x i32> [[TMP1]]
;
%hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
%2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
%3 = add <3 x i32> %1, %2
ret <3 x i32> %3
}
Loading
Loading