@@ -100,6 +100,13 @@ static bool isInterleavingMask(ArrayRef<int> Mask);
100
100
// / <1, 3, 5, 7>).
101
101
static bool isDeinterleavingMask (ArrayRef<int > Mask);
102
102
103
+ // / Returns true if the operation is a negation of V, and it works for both
104
+ // / integers and floats.
105
+ static bool isNeg (Value *V);
106
+
107
+ // / Returns the operand for negation operation.
108
+ static Value *getNegOperand (Value *V);
109
+
103
110
namespace {
104
111
105
112
class ComplexDeinterleavingLegacyPass : public FunctionPass {
@@ -146,7 +153,7 @@ struct ComplexDeinterleavingCompositeNode {
146
153
// This two members are required exclusively for generating
147
154
// ComplexDeinterleavingOperation::Symmetric operations.
148
155
unsigned Opcode;
149
- FastMathFlags Flags;
156
+ std::optional< FastMathFlags> Flags;
150
157
151
158
ComplexDeinterleavingRotation Rotation =
152
159
ComplexDeinterleavingRotation::Rotation_0;
@@ -333,7 +340,8 @@ class ComplexDeinterleavingGraph {
333
340
// / Return nullptr if it is not possible to construct a complex number.
334
341
// / \p Flags are needed to generate symmetric Add and Sub operations.
335
342
NodePtr identifyAdditions (std::list<Addend> &RealAddends,
336
- std::list<Addend> &ImagAddends, FastMathFlags Flags,
343
+ std::list<Addend> &ImagAddends,
344
+ std::optional<FastMathFlags> Flags,
337
345
NodePtr Accumulator);
338
346
339
347
// / Extract one addend that have both real and imaginary parts positive.
@@ -512,6 +520,19 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
512
520
return true ;
513
521
}
514
522
523
+ bool isNeg (Value *V) {
524
+ return match (V, m_FNeg (m_Value ())) || match (V, m_Neg (m_Value ()));
525
+ }
526
+
527
+ Value *getNegOperand (Value *V) {
528
+ assert (isNeg (V));
529
+ auto *I = cast<Instruction>(V);
530
+ if (I->getOpcode () == Instruction::FNeg)
531
+ return I->getOperand (0 );
532
+
533
+ return I->getOperand (1 );
534
+ }
535
+
515
536
bool ComplexDeinterleaving::evaluateBasicBlock (BasicBlock *B) {
516
537
ComplexDeinterleavingGraph Graph (TL, TLI);
517
538
if (Graph.collectPotentialReductions (B))
@@ -540,9 +561,12 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
540
561
return nullptr ;
541
562
}
542
563
543
- if (Real->getOpcode () != Instruction::FMul ||
544
- Imag->getOpcode () != Instruction::FMul) {
545
- LLVM_DEBUG (dbgs () << " - Real or imaginary instruction is not fmul\n " );
564
+ if ((Real->getOpcode () != Instruction::FMul &&
565
+ Real->getOpcode () != Instruction::Mul) ||
566
+ (Imag->getOpcode () != Instruction::FMul &&
567
+ Imag->getOpcode () != Instruction::Mul)) {
568
+ LLVM_DEBUG (
569
+ dbgs () << " - Real or imaginary instruction is not fmul or mul\n " );
546
570
return nullptr ;
547
571
}
548
572
@@ -563,7 +587,7 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
563
587
R1 = Op;
564
588
}
565
589
566
- if (match (I0, m_Neg ( m_Value (Op)) )) {
590
+ if (isNeg (I0)) {
567
591
Negs |= 2 ;
568
592
Negs ^= 1 ;
569
593
I0 = Op;
@@ -634,26 +658,29 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
634
658
LLVM_DEBUG (dbgs () << " identifyPartialMul " << *Real << " / " << *Imag
635
659
<< " \n " );
636
660
// Determine rotation
661
+ auto IsAdd = [](unsigned Op) {
662
+ return Op == Instruction::FAdd || Op == Instruction::Add;
663
+ };
664
+ auto IsSub = [](unsigned Op) {
665
+ return Op == Instruction::FSub || Op == Instruction::Sub;
666
+ };
637
667
ComplexDeinterleavingRotation Rotation;
638
- if (Real->getOpcode () == Instruction::FAdd &&
639
- Imag->getOpcode () == Instruction::FAdd)
668
+ if (IsAdd (Real->getOpcode ()) && IsAdd (Imag->getOpcode ()))
640
669
Rotation = ComplexDeinterleavingRotation::Rotation_0;
641
- else if (Real->getOpcode () == Instruction::FSub &&
642
- Imag->getOpcode () == Instruction::FAdd)
670
+ else if (IsSub (Real->getOpcode ()) && IsAdd (Imag->getOpcode ()))
643
671
Rotation = ComplexDeinterleavingRotation::Rotation_90;
644
- else if (Real->getOpcode () == Instruction::FSub &&
645
- Imag->getOpcode () == Instruction::FSub)
672
+ else if (IsSub (Real->getOpcode ()) && IsSub (Imag->getOpcode ()))
646
673
Rotation = ComplexDeinterleavingRotation::Rotation_180;
647
- else if (Real->getOpcode () == Instruction::FAdd &&
648
- Imag->getOpcode () == Instruction::FSub)
674
+ else if (IsAdd (Real->getOpcode ()) && IsSub (Imag->getOpcode ()))
649
675
Rotation = ComplexDeinterleavingRotation::Rotation_270;
650
676
else {
651
677
LLVM_DEBUG (dbgs () << " - Unhandled rotation.\n " );
652
678
return nullptr ;
653
679
}
654
680
655
- if (!Real->getFastMathFlags ().allowContract () ||
656
- !Imag->getFastMathFlags ().allowContract ()) {
681
+ if (isa<FPMathOperator>(Real) &&
682
+ (!Real->getFastMathFlags ().allowContract () ||
683
+ !Imag->getFastMathFlags ().allowContract ())) {
657
684
LLVM_DEBUG (dbgs () << " - Contract is missing from the FastMath flags.\n " );
658
685
return nullptr ;
659
686
}
@@ -816,6 +843,9 @@ static bool isInstructionPotentiallySymmetric(Instruction *I) {
816
843
case Instruction::FSub:
817
844
case Instruction::FMul:
818
845
case Instruction::FNeg:
846
+ case Instruction::Add:
847
+ case Instruction::Sub:
848
+ case Instruction::Mul:
819
849
return true ;
820
850
default :
821
851
return false ;
@@ -925,27 +955,31 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
925
955
ComplexDeinterleavingGraph::NodePtr
926
956
ComplexDeinterleavingGraph::identifyReassocNodes (Instruction *Real,
927
957
Instruction *Imag) {
958
+ auto IsOperationSupported = [](unsigned Opcode) -> bool {
959
+ return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
960
+ Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
961
+ Opcode == Instruction::Sub;
962
+ };
928
963
929
- if ((Real->getOpcode () != Instruction::FAdd &&
930
- Real->getOpcode () != Instruction::FSub &&
931
- Real->getOpcode () != Instruction::FNeg) ||
932
- (Imag->getOpcode () != Instruction::FAdd &&
933
- Imag->getOpcode () != Instruction::FSub &&
934
- Imag->getOpcode () != Instruction::FNeg))
964
+ if (!IsOperationSupported (Real->getOpcode ()) ||
965
+ !IsOperationSupported (Imag->getOpcode ()))
935
966
return nullptr ;
936
967
937
- if (Real->getFastMathFlags () != Imag->getFastMathFlags ()) {
938
- LLVM_DEBUG (
939
- dbgs ()
940
- << " The flags in Real and Imaginary instructions are not identical\n " );
941
- return nullptr ;
942
- }
968
+ std::optional<FastMathFlags> Flags;
969
+ if (isa<FPMathOperator>(Real)) {
970
+ if (Real->getFastMathFlags () != Imag->getFastMathFlags ()) {
971
+ LLVM_DEBUG (dbgs () << " The flags in Real and Imaginary instructions are "
972
+ " not identical\n " );
973
+ return nullptr ;
974
+ }
943
975
944
- FastMathFlags Flags = Real->getFastMathFlags ();
945
- if (!Flags.allowReassoc ()) {
946
- LLVM_DEBUG (
947
- dbgs () << " the 'Reassoc' attribute is missing in the FastMath flags\n " );
948
- return nullptr ;
976
+ Flags = Real->getFastMathFlags ();
977
+ if (!Flags->allowReassoc ()) {
978
+ LLVM_DEBUG (
979
+ dbgs ()
980
+ << " the 'Reassoc' attribute is missing in the FastMath flags\n " );
981
+ return nullptr ;
982
+ }
949
983
}
950
984
951
985
// Collect multiplications and addend instructions from the given instruction
@@ -978,35 +1012,52 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
978
1012
Addends.emplace_back (I, IsPositive);
979
1013
continue ;
980
1014
}
981
-
982
- if (I->getOpcode () == Instruction::FAdd) {
1015
+ switch (I->getOpcode ()) {
1016
+ case Instruction::FAdd:
1017
+ case Instruction::Add:
983
1018
Worklist.emplace_back (I->getOperand (1 ), IsPositive);
984
1019
Worklist.emplace_back (I->getOperand (0 ), IsPositive);
985
- } else if (I->getOpcode () == Instruction::FSub) {
1020
+ break ;
1021
+ case Instruction::FSub:
986
1022
Worklist.emplace_back (I->getOperand (1 ), !IsPositive);
987
1023
Worklist.emplace_back (I->getOperand (0 ), IsPositive);
988
- } else if (I->getOpcode () == Instruction::FMul) {
1024
+ break ;
1025
+ case Instruction::Sub:
1026
+ if (isNeg (I)) {
1027
+ Worklist.emplace_back (getNegOperand (I), !IsPositive);
1028
+ } else {
1029
+ Worklist.emplace_back (I->getOperand (1 ), !IsPositive);
1030
+ Worklist.emplace_back (I->getOperand (0 ), IsPositive);
1031
+ }
1032
+ break ;
1033
+ case Instruction::FMul:
1034
+ case Instruction::Mul: {
989
1035
Value *A, *B;
990
- if (match (I->getOperand (0 ), m_FNeg (m_Value (A)))) {
1036
+ if (isNeg (I->getOperand (0 ))) {
1037
+ A = getNegOperand (I->getOperand (0 ));
991
1038
IsPositive = !IsPositive;
992
1039
} else {
993
1040
A = I->getOperand (0 );
994
1041
}
995
1042
996
- if (match (I->getOperand (1 ), m_FNeg (m_Value (B)))) {
1043
+ if (isNeg (I->getOperand (1 ))) {
1044
+ B = getNegOperand (I->getOperand (1 ));
997
1045
IsPositive = !IsPositive;
998
1046
} else {
999
1047
B = I->getOperand (1 );
1000
1048
}
1001
1049
Muls.push_back (Product{A, B, IsPositive});
1002
- } else if (I->getOpcode () == Instruction::FNeg) {
1050
+ break ;
1051
+ }
1052
+ case Instruction::FNeg:
1003
1053
Worklist.emplace_back (I->getOperand (0 ), !IsPositive);
1004
- } else {
1054
+ break ;
1055
+ default :
1005
1056
Addends.emplace_back (I, IsPositive);
1006
1057
continue ;
1007
1058
}
1008
1059
1009
- if (I->getFastMathFlags () != Flags) {
1060
+ if (Flags && I->getFastMathFlags () != * Flags) {
1010
1061
LLVM_DEBUG (dbgs () << " The instruction's fast math flags are "
1011
1062
" inconsistent with the root instructions' flags: "
1012
1063
<< *I << " \n " );
@@ -1258,10 +1309,9 @@ ComplexDeinterleavingGraph::identifyMultiplications(
1258
1309
}
1259
1310
1260
1311
ComplexDeinterleavingGraph::NodePtr
1261
- ComplexDeinterleavingGraph::identifyAdditions (std::list<Addend> &RealAddends,
1262
- std::list<Addend> &ImagAddends,
1263
- FastMathFlags Flags,
1264
- NodePtr Accumulator = nullptr ) {
1312
+ ComplexDeinterleavingGraph::identifyAdditions (
1313
+ std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1314
+ std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr ) {
1265
1315
if (RealAddends.size () != ImagAddends.size ())
1266
1316
return nullptr ;
1267
1317
@@ -1312,14 +1362,22 @@ ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1312
1362
if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1313
1363
TmpNode = prepareCompositeNode (
1314
1364
ComplexDeinterleavingOperation::Symmetric, nullptr , nullptr );
1315
- TmpNode->Opcode = Instruction::FAdd;
1316
- TmpNode->Flags = Flags;
1365
+ if (Flags) {
1366
+ TmpNode->Opcode = Instruction::FAdd;
1367
+ TmpNode->Flags = *Flags;
1368
+ } else {
1369
+ TmpNode->Opcode = Instruction::Add;
1370
+ }
1317
1371
} else if (Rotation ==
1318
1372
llvm::ComplexDeinterleavingRotation::Rotation_180) {
1319
1373
TmpNode = prepareCompositeNode (
1320
1374
ComplexDeinterleavingOperation::Symmetric, nullptr , nullptr );
1321
- TmpNode->Opcode = Instruction::FSub;
1322
- TmpNode->Flags = Flags;
1375
+ if (Flags) {
1376
+ TmpNode->Opcode = Instruction::FSub;
1377
+ TmpNode->Flags = *Flags;
1378
+ } else {
1379
+ TmpNode->Opcode = Instruction::Sub;
1380
+ }
1323
1381
} else {
1324
1382
TmpNode = prepareCompositeNode (ComplexDeinterleavingOperation::CAdd,
1325
1383
nullptr , nullptr );
@@ -1815,8 +1873,8 @@ ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1815
1873
}
1816
1874
1817
1875
static Value *replaceSymmetricNode (IRBuilderBase &B, unsigned Opcode,
1818
- FastMathFlags Flags, Value *InputA ,
1819
- Value *InputB) {
1876
+ std::optional< FastMathFlags> Flags,
1877
+ Value *InputA, Value * InputB) {
1820
1878
Value *I;
1821
1879
switch (Opcode) {
1822
1880
case Instruction::FNeg:
@@ -1825,16 +1883,26 @@ static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1825
1883
case Instruction::FAdd:
1826
1884
I = B.CreateFAdd (InputA, InputB);
1827
1885
break ;
1886
+ case Instruction::Add:
1887
+ I = B.CreateAdd (InputA, InputB);
1888
+ break ;
1828
1889
case Instruction::FSub:
1829
1890
I = B.CreateFSub (InputA, InputB);
1830
1891
break ;
1892
+ case Instruction::Sub:
1893
+ I = B.CreateSub (InputA, InputB);
1894
+ break ;
1831
1895
case Instruction::FMul:
1832
1896
I = B.CreateFMul (InputA, InputB);
1833
1897
break ;
1898
+ case Instruction::Mul:
1899
+ I = B.CreateMul (InputA, InputB);
1900
+ break ;
1834
1901
default :
1835
1902
llvm_unreachable (" Incorrect symmetric opcode" );
1836
1903
}
1837
- cast<Instruction>(I)->setFastMathFlags (Flags);
1904
+ if (Flags)
1905
+ cast<Instruction>(I)->setFastMathFlags (*Flags);
1838
1906
return I;
1839
1907
}
1840
1908
0 commit comments