Skip to content

Commit c15557d

Browse files
committed
[CodeGen] Extend ComplexDeinterleaving pass to recognise patterns using integer types
AArch64 introduced CMLA and CADD instructions as part of SVE2. This change allows to generate such instructions when this architecture feature is available. Differential Revision: https://reviews.llvm.org/D153808
1 parent 98b0f13 commit c15557d

9 files changed

+910
-58
lines changed

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Lines changed: 121 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ static bool isInterleavingMask(ArrayRef<int> Mask);
100100
/// <1, 3, 5, 7>).
101101
static bool isDeinterleavingMask(ArrayRef<int> Mask);
102102

103+
/// Returns true if the operation is a negation of V, and it works for both
104+
/// integers and floats.
105+
static bool isNeg(Value *V);
106+
107+
/// Returns the operand for negation operation.
108+
static Value *getNegOperand(Value *V);
109+
103110
namespace {
104111

105112
class ComplexDeinterleavingLegacyPass : public FunctionPass {
@@ -146,7 +153,7 @@ struct ComplexDeinterleavingCompositeNode {
146153
// This two members are required exclusively for generating
147154
// ComplexDeinterleavingOperation::Symmetric operations.
148155
unsigned Opcode;
149-
FastMathFlags Flags;
156+
std::optional<FastMathFlags> Flags;
150157

151158
ComplexDeinterleavingRotation Rotation =
152159
ComplexDeinterleavingRotation::Rotation_0;
@@ -333,7 +340,8 @@ class ComplexDeinterleavingGraph {
333340
/// Return nullptr if it is not possible to construct a complex number.
334341
/// \p Flags are needed to generate symmetric Add and Sub operations.
335342
NodePtr identifyAdditions(std::list<Addend> &RealAddends,
336-
std::list<Addend> &ImagAddends, FastMathFlags Flags,
343+
std::list<Addend> &ImagAddends,
344+
std::optional<FastMathFlags> Flags,
337345
NodePtr Accumulator);
338346

339347
/// Extract one addend that have both real and imaginary parts positive.
@@ -512,6 +520,19 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
512520
return true;
513521
}
514522

523+
bool isNeg(Value *V) {
524+
return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
525+
}
526+
527+
Value *getNegOperand(Value *V) {
528+
assert(isNeg(V));
529+
auto *I = cast<Instruction>(V);
530+
if (I->getOpcode() == Instruction::FNeg)
531+
return I->getOperand(0);
532+
533+
return I->getOperand(1);
534+
}
535+
515536
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
516537
ComplexDeinterleavingGraph Graph(TL, TLI);
517538
if (Graph.collectPotentialReductions(B))
@@ -540,9 +561,12 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
540561
return nullptr;
541562
}
542563

543-
if (Real->getOpcode() != Instruction::FMul ||
544-
Imag->getOpcode() != Instruction::FMul) {
545-
LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
564+
if ((Real->getOpcode() != Instruction::FMul &&
565+
Real->getOpcode() != Instruction::Mul) ||
566+
(Imag->getOpcode() != Instruction::FMul &&
567+
Imag->getOpcode() != Instruction::Mul)) {
568+
LLVM_DEBUG(
569+
dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
546570
return nullptr;
547571
}
548572

@@ -563,7 +587,7 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
563587
R1 = Op;
564588
}
565589

566-
if (match(I0, m_Neg(m_Value(Op)))) {
590+
if (isNeg(I0)) {
567591
Negs |= 2;
568592
Negs ^= 1;
569593
I0 = Op;
@@ -634,26 +658,29 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
634658
LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
635659
<< "\n");
636660
// Determine rotation
661+
auto IsAdd = [](unsigned Op) {
662+
return Op == Instruction::FAdd || Op == Instruction::Add;
663+
};
664+
auto IsSub = [](unsigned Op) {
665+
return Op == Instruction::FSub || Op == Instruction::Sub;
666+
};
637667
ComplexDeinterleavingRotation Rotation;
638-
if (Real->getOpcode() == Instruction::FAdd &&
639-
Imag->getOpcode() == Instruction::FAdd)
668+
if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
640669
Rotation = ComplexDeinterleavingRotation::Rotation_0;
641-
else if (Real->getOpcode() == Instruction::FSub &&
642-
Imag->getOpcode() == Instruction::FAdd)
670+
else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
643671
Rotation = ComplexDeinterleavingRotation::Rotation_90;
644-
else if (Real->getOpcode() == Instruction::FSub &&
645-
Imag->getOpcode() == Instruction::FSub)
672+
else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
646673
Rotation = ComplexDeinterleavingRotation::Rotation_180;
647-
else if (Real->getOpcode() == Instruction::FAdd &&
648-
Imag->getOpcode() == Instruction::FSub)
674+
else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
649675
Rotation = ComplexDeinterleavingRotation::Rotation_270;
650676
else {
651677
LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
652678
return nullptr;
653679
}
654680

655-
if (!Real->getFastMathFlags().allowContract() ||
656-
!Imag->getFastMathFlags().allowContract()) {
681+
if (isa<FPMathOperator>(Real) &&
682+
(!Real->getFastMathFlags().allowContract() ||
683+
!Imag->getFastMathFlags().allowContract())) {
657684
LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
658685
return nullptr;
659686
}
@@ -816,6 +843,9 @@ static bool isInstructionPotentiallySymmetric(Instruction *I) {
816843
case Instruction::FSub:
817844
case Instruction::FMul:
818845
case Instruction::FNeg:
846+
case Instruction::Add:
847+
case Instruction::Sub:
848+
case Instruction::Mul:
819849
return true;
820850
default:
821851
return false;
@@ -925,27 +955,31 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
925955
ComplexDeinterleavingGraph::NodePtr
926956
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
927957
Instruction *Imag) {
958+
auto IsOperationSupported = [](unsigned Opcode) -> bool {
959+
return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
960+
Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
961+
Opcode == Instruction::Sub;
962+
};
928963

929-
if ((Real->getOpcode() != Instruction::FAdd &&
930-
Real->getOpcode() != Instruction::FSub &&
931-
Real->getOpcode() != Instruction::FNeg) ||
932-
(Imag->getOpcode() != Instruction::FAdd &&
933-
Imag->getOpcode() != Instruction::FSub &&
934-
Imag->getOpcode() != Instruction::FNeg))
964+
if (!IsOperationSupported(Real->getOpcode()) ||
965+
!IsOperationSupported(Imag->getOpcode()))
935966
return nullptr;
936967

937-
if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
938-
LLVM_DEBUG(
939-
dbgs()
940-
<< "The flags in Real and Imaginary instructions are not identical\n");
941-
return nullptr;
942-
}
968+
std::optional<FastMathFlags> Flags;
969+
if (isa<FPMathOperator>(Real)) {
970+
if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
971+
LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
972+
"not identical\n");
973+
return nullptr;
974+
}
943975

944-
FastMathFlags Flags = Real->getFastMathFlags();
945-
if (!Flags.allowReassoc()) {
946-
LLVM_DEBUG(
947-
dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
948-
return nullptr;
976+
Flags = Real->getFastMathFlags();
977+
if (!Flags->allowReassoc()) {
978+
LLVM_DEBUG(
979+
dbgs()
980+
<< "the 'Reassoc' attribute is missing in the FastMath flags\n");
981+
return nullptr;
982+
}
949983
}
950984

951985
// Collect multiplications and addend instructions from the given instruction
@@ -978,35 +1012,52 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
9781012
Addends.emplace_back(I, IsPositive);
9791013
continue;
9801014
}
981-
982-
if (I->getOpcode() == Instruction::FAdd) {
1015+
switch (I->getOpcode()) {
1016+
case Instruction::FAdd:
1017+
case Instruction::Add:
9831018
Worklist.emplace_back(I->getOperand(1), IsPositive);
9841019
Worklist.emplace_back(I->getOperand(0), IsPositive);
985-
} else if (I->getOpcode() == Instruction::FSub) {
1020+
break;
1021+
case Instruction::FSub:
9861022
Worklist.emplace_back(I->getOperand(1), !IsPositive);
9871023
Worklist.emplace_back(I->getOperand(0), IsPositive);
988-
} else if (I->getOpcode() == Instruction::FMul) {
1024+
break;
1025+
case Instruction::Sub:
1026+
if (isNeg(I)) {
1027+
Worklist.emplace_back(getNegOperand(I), !IsPositive);
1028+
} else {
1029+
Worklist.emplace_back(I->getOperand(1), !IsPositive);
1030+
Worklist.emplace_back(I->getOperand(0), IsPositive);
1031+
}
1032+
break;
1033+
case Instruction::FMul:
1034+
case Instruction::Mul: {
9891035
Value *A, *B;
990-
if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
1036+
if (isNeg(I->getOperand(0))) {
1037+
A = getNegOperand(I->getOperand(0));
9911038
IsPositive = !IsPositive;
9921039
} else {
9931040
A = I->getOperand(0);
9941041
}
9951042

996-
if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
1043+
if (isNeg(I->getOperand(1))) {
1044+
B = getNegOperand(I->getOperand(1));
9971045
IsPositive = !IsPositive;
9981046
} else {
9991047
B = I->getOperand(1);
10001048
}
10011049
Muls.push_back(Product{A, B, IsPositive});
1002-
} else if (I->getOpcode() == Instruction::FNeg) {
1050+
break;
1051+
}
1052+
case Instruction::FNeg:
10031053
Worklist.emplace_back(I->getOperand(0), !IsPositive);
1004-
} else {
1054+
break;
1055+
default:
10051056
Addends.emplace_back(I, IsPositive);
10061057
continue;
10071058
}
10081059

1009-
if (I->getFastMathFlags() != Flags) {
1060+
if (Flags && I->getFastMathFlags() != *Flags) {
10101061
LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
10111062
"inconsistent with the root instructions' flags: "
10121063
<< *I << "\n");
@@ -1258,10 +1309,9 @@ ComplexDeinterleavingGraph::identifyMultiplications(
12581309
}
12591310

12601311
ComplexDeinterleavingGraph::NodePtr
1261-
ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1262-
std::list<Addend> &ImagAddends,
1263-
FastMathFlags Flags,
1264-
NodePtr Accumulator = nullptr) {
1312+
ComplexDeinterleavingGraph::identifyAdditions(
1313+
std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1314+
std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
12651315
if (RealAddends.size() != ImagAddends.size())
12661316
return nullptr;
12671317

@@ -1312,14 +1362,22 @@ ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
13121362
if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
13131363
TmpNode = prepareCompositeNode(
13141364
ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1315-
TmpNode->Opcode = Instruction::FAdd;
1316-
TmpNode->Flags = Flags;
1365+
if (Flags) {
1366+
TmpNode->Opcode = Instruction::FAdd;
1367+
TmpNode->Flags = *Flags;
1368+
} else {
1369+
TmpNode->Opcode = Instruction::Add;
1370+
}
13171371
} else if (Rotation ==
13181372
llvm::ComplexDeinterleavingRotation::Rotation_180) {
13191373
TmpNode = prepareCompositeNode(
13201374
ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1321-
TmpNode->Opcode = Instruction::FSub;
1322-
TmpNode->Flags = Flags;
1375+
if (Flags) {
1376+
TmpNode->Opcode = Instruction::FSub;
1377+
TmpNode->Flags = *Flags;
1378+
} else {
1379+
TmpNode->Opcode = Instruction::Sub;
1380+
}
13231381
} else {
13241382
TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
13251383
nullptr, nullptr);
@@ -1815,8 +1873,8 @@ ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
18151873
}
18161874

18171875
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1818-
FastMathFlags Flags, Value *InputA,
1819-
Value *InputB) {
1876+
std::optional<FastMathFlags> Flags,
1877+
Value *InputA, Value *InputB) {
18201878
Value *I;
18211879
switch (Opcode) {
18221880
case Instruction::FNeg:
@@ -1825,16 +1883,26 @@ static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
18251883
case Instruction::FAdd:
18261884
I = B.CreateFAdd(InputA, InputB);
18271885
break;
1886+
case Instruction::Add:
1887+
I = B.CreateAdd(InputA, InputB);
1888+
break;
18281889
case Instruction::FSub:
18291890
I = B.CreateFSub(InputA, InputB);
18301891
break;
1892+
case Instruction::Sub:
1893+
I = B.CreateSub(InputA, InputB);
1894+
break;
18311895
case Instruction::FMul:
18321896
I = B.CreateFMul(InputA, InputB);
18331897
break;
1898+
case Instruction::Mul:
1899+
I = B.CreateMul(InputA, InputB);
1900+
break;
18341901
default:
18351902
llvm_unreachable("Incorrect symmetric opcode");
18361903
}
1837-
cast<Instruction>(I)->setFastMathFlags(Flags);
1904+
if (Flags)
1905+
cast<Instruction>(I)->setFastMathFlags(*Flags);
18381906
return I;
18391907
}
18401908

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25858,7 +25858,8 @@ bool AArch64TargetLowering::isConstantUnsignedBitfieldExtractLegal(
2585825858
}
2585925859

2586025860
bool AArch64TargetLowering::isComplexDeinterleavingSupported() const {
25861-
return Subtarget->hasSVE() || Subtarget->hasComplxNum();
25861+
return Subtarget->hasSVE() || Subtarget->hasSVE2() ||
25862+
Subtarget->hasComplxNum();
2586225863
}
2586325864

2586425865
bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
@@ -25884,6 +25885,11 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
2588425885
!llvm::isPowerOf2_32(VTyWidth))
2588525886
return false;
2588625887

25888+
if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2()) {
25889+
unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
25890+
return 8 <= ScalarWidth && ScalarWidth <= 64;
25891+
}
25892+
2588725893
return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
2588825894
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
2588925895
}
@@ -25894,6 +25900,7 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2589425900
Value *Accumulator) const {
2589525901
VectorType *Ty = cast<VectorType>(InputA->getType());
2589625902
bool IsScalable = Ty->isScalableTy();
25903+
bool IsInt = Ty->getElementType()->isIntegerTy();
2589725904

2589825905
unsigned TyWidth =
2589925906
Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue();
@@ -25929,10 +25936,15 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2592925936

2593025937
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
2593125938
if (Accumulator == nullptr)
25932-
Accumulator = ConstantFP::get(Ty, 0);
25939+
Accumulator = Constant::getNullValue(Ty);
2593325940

2593425941
if (IsScalable) {
25935-
auto *Mask = B.CreateVectorSplat(Ty->getElementCount(), B.getInt1(true));
25942+
if (IsInt)
25943+
return B.CreateIntrinsic(
25944+
Intrinsic::aarch64_sve_cmla_x, Ty,
25945+
{Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
25946+
25947+
auto *Mask = B.getAllOnesMask(Ty->getElementCount());
2593625948
return B.CreateIntrinsic(
2593725949
Intrinsic::aarch64_sve_fcmla, Ty,
2593825950
{Mask, Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
@@ -25950,12 +25962,18 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
2595025962

2595125963
if (OperationType == ComplexDeinterleavingOperation::CAdd) {
2595225964
if (IsScalable) {
25953-
auto *Mask = B.CreateVectorSplat(Ty->getElementCount(), B.getInt1(true));
2595425965
if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
25955-
Rotation == ComplexDeinterleavingRotation::Rotation_270)
25966+
Rotation == ComplexDeinterleavingRotation::Rotation_270) {
25967+
if (IsInt)
25968+
return B.CreateIntrinsic(
25969+
Intrinsic::aarch64_sve_cadd_x, Ty,
25970+
{InputA, InputB, B.getInt32((int)Rotation * 90)});
25971+
25972+
auto *Mask = B.getAllOnesMask(Ty->getElementCount());
2595625973
return B.CreateIntrinsic(
2595725974
Intrinsic::aarch64_sve_fcadd, Ty,
2595825975
{Mask, InputA, InputB, B.getInt32((int)Rotation * 90)});
25976+
}
2595925977
return nullptr;
2596025978
}
2596125979

0 commit comments

Comments
 (0)