-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MacroFusion] Support commutable instructions #82751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
1b286ff
to
9d9a50d
Compare
9d9a50d
to
048d1bd
Compare
048d1bd
to
74c60af
Compare
@llvm/pr-subscribers-backend-risc-v Author: Wang Pengcheng (wangpc-pp) ChangesIf the second instruction is commutable, we should be able to check A simple RISCV fusion is contained in this PR to show the functionality There are some other issues I should fix. For example, we should be Fixes #82738 Full diff: https://github.com/llvm/llvm-project/pull/82751.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Target/TargetSchedule.td b/llvm/include/llvm/Target/TargetSchedule.td
index 48c9387977c075..a872cc30a9b50c 100644
--- a/llvm/include/llvm/Target/TargetSchedule.td
+++ b/llvm/include/llvm/Target/TargetSchedule.td
@@ -617,16 +617,27 @@ class SecondFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
: FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
// The pred will be applied on both firstMI and secondMI.
class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
- : FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
+ : FusionPredicateWithMCInstPredicate<both_fusion_target, pred>;
// Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
// `firstOpIdx` should be the same as the operand of `SecondMI` at position
// `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
int FirstOpIdx = firstOpIdx;
int SecondOpIdx = secondOpIdx;
}
+// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
+// operand at position `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
+class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
+ int FirstOpIdx = firstOpIdx;
+ int SecondOpIdx = secondOpIdx;
+}
+
// A predicate for wildcard. The generated code will be like:
// ```
// if (!FirstMI)
@@ -688,11 +699,7 @@ class SimpleFusion<string name, string fieldName, string desc,
SecondFusionPredicateWithMCInstPredicate<secondPred>,
WildcardTrue,
FirstFusionPredicateWithMCInstPredicate<firstPred>,
- SecondFusionPredicateWithMCInstPredicate<
- CheckAny<[
- CheckIsVRegOperand<0>,
- CheckSameRegOperand<0, 1>
- ]>>,
+ SameReg<0, 1>,
OneUse,
TieReg<0, 1>,
],
diff --git a/llvm/lib/Target/RISCV/RISCVMacroFusion.td b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
index 875a93d09a2c64..14e8962f8ce110 100644
--- a/llvm/lib/Target/RISCV/RISCVMacroFusion.td
+++ b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
@@ -91,3 +91,24 @@ def TuneLDADDFusion
CheckIsImmOperand<2>,
CheckImmOperand<2, 0>
]>>;
+
+// These should be covered by Zba extension.
+// * shift left by one and add:
+// slli r1, r0, 1
+// add r1, r1, r2
+// * shift left by two and add:
+// slli r1, r0, 2
+// add r1, r1, r2
+// * shift left by three and add:
+// slli r1, r0, 3
+// add r1, r1, r2
+def ShiftNAddFusion
+ : SimpleFusion<"shift-n-add-fusion", "HasShiftNAddFusion",
+ "Enable SLLI+ADD to be fused to shift left by 1/2/3 and add",
+ CheckAll<[
+ CheckOpcode<[SLLI]>,
+ CheckAny<[CheckImmOperand<2, 1>,
+ CheckImmOperand<2, 2>,
+ CheckImmOperand<2, 3>]>
+ ]>,
+ CheckOpcode<[ADD]>>;
diff --git a/llvm/test/CodeGen/RISCV/macro-fusions.mir b/llvm/test/CodeGen/RISCV/macro-fusions.mir
index 13464141ce27e6..11bb456e0ca050 100644
--- a/llvm/test/CodeGen/RISCV/macro-fusions.mir
+++ b/llvm/test/CodeGen/RISCV/macro-fusions.mir
@@ -1,7 +1,7 @@
# REQUIRES: asserts
# RUN: llc -mtriple=riscv64-linux-gnu -x=mir < %s \
# RUN: -debug-only=machine-scheduler -start-before=machine-scheduler 2>&1 \
-# RUN: -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion \
+# RUN: -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion,+shift-n-add-fusion \
# RUN: | FileCheck %s
# CHECK: lui_addi:%bb.0
@@ -174,3 +174,31 @@ body: |
$x11 = COPY %5
PseudoRET
...
+
+# CHECK: shift_n_add:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add
+tracksRegLiveness: true
+body: |
+ bb.0.entry:
+ liveins: $x10, $x11
+ $x10 = SLLI $x10, 1
+ $x12 = XORI $x11, 3
+ $x10 = ADD $x10, $x11
+ PseudoRET
+...
+
+# CHECK: shift_n_add_commutable:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add_commutable
+tracksRegLiveness: true
+body: |
+ bb.0.entry:
+ liveins: $x10, $x11
+ $x10 = SLLI $x10, 1
+ $x12 = XORI $x11, 3
+ $x10 = ADD $x11, $x10
+ PseudoRET
+...
diff --git a/llvm/test/TableGen/MacroFusion.td b/llvm/test/TableGen/MacroFusion.td
index 4aa6c8d9acb273..c1cdd0fee78ccd 100644
--- a/llvm/test/TableGen/MacroFusion.td
+++ b/llvm/test/TableGen/MacroFusion.td
@@ -34,6 +34,11 @@ let Namespace = "Test" in {
def Inst0 : TestInst<0>;
def Inst1 : TestInst<1>;
+def BothFusionPredicate: BothFusionPredicateWithMCInstPredicate<CheckRegOperand<0, X0>>;
+def TestBothFusionPredicate: Fusion<"test-both-fusion-predicate", "HasBothFusionPredicate",
+ "Test BothFusionPredicate",
+ [BothFusionPredicate]>;
+
def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
CheckOpcode<[Inst0]>,
CheckAll<[
@@ -45,6 +50,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: #undef GET_Test_MACRO_FUSION_PRED_DECL
// CHECK-PREDICATOR-EMPTY:
// CHECK-PREDICATOR-NEXT: namespace llvm {
+// CHECK-PREDICATOR-NEXT: bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
// CHECK-PREDICATOR-NEXT: bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
// CHECK-PREDICATOR-NEXT: } // end namespace llvm
// CHECK-PREDICATOR-EMPTY:
@@ -54,6 +60,24 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: #undef GET_Test_MACRO_FUSION_PRED_IMPL
// CHECK-PREDICATOR-EMPTY:
// CHECK-PREDICATOR-NEXT: namespace llvm {
+// CHECK-PREDICATOR-NEXT: bool isTestBothFusionPredicate(
+// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
+// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
+// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI,
+// CHECK-PREDICATOR-NEXT: const MachineInstr &SecondMI) {
+// CHECK-PREDICATOR-NEXT: auto &MRI = SecondMI.getMF()->getRegInfo();
+// CHECK-PREDICATOR-NEXT: {
+// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = FirstMI;
+// CHECK-PREDICATOR-NEXT: if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: }
+// CHECK-PREDICATOR-NEXT: {
+// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = &SecondMI;
+// CHECK-PREDICATOR-NEXT: if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: }
+// CHECK-PREDICATOR-NEXT: return true;
+// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: bool isTestFusion(
// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
@@ -75,13 +99,15 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 ))
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
-// CHECK-PREDICATOR-NEXT: {
-// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = &SecondMI;
-// CHECK-PREDICATOR-NEXT: if (!(
-// CHECK-PREDICATOR-NEXT: MI->getOperand(0).getReg().isVirtual()
-// CHECK-PREDICATOR-NEXT: || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
-// CHECK-PREDICATOR-NEXT: ))
-// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) {
+// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {
+// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: {
// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg();
@@ -90,8 +116,14 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: if (!(FirstMI->getOperand(0).isReg() &&
// CHECK-PREDICATOR-NEXT: SecondMI.getOperand(1).isReg() &&
-// CHECK-PREDICATOR-NEXT: FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg()))
-// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
+// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT: if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT: return false;
+// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: return true;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: } // end namespace llvm
@@ -106,6 +138,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-SUBTARGET: std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
// CHECK-SUBTARGET-NEXT: std::vector<MacroFusionPredTy> Fusions;
+// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate);
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
// CHECK-SUBTARGET-NEXT: return Fusions;
// CHECK-SUBTARGET-NEXT: }
diff --git a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
index 78dcd4471ae747..1042dd9c2dbccf 100644
--- a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
+++ b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
@@ -152,8 +152,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
<< "if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))\n";
OS.indent(4) << " return false;\n";
OS.indent(2) << "}\n";
- } else if (Predicate->isSubClassOf(
- "FirstFusionPredicateWithMCInstPredicate")) {
+ } else if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
OS.indent(2) << "{\n";
OS.indent(4) << "const MachineInstr *MI = FirstMI;\n";
OS.indent(4) << "if (";
@@ -173,7 +172,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
PredicateExpander &PE,
raw_ostream &OS) {
- if (Predicate->isSubClassOf("SecondFusionPredicateWithMCInstPredicate")) {
+ if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
OS.indent(2) << "{\n";
OS.indent(4) << "const MachineInstr *MI = &SecondMI;\n";
OS.indent(4) << "if (";
@@ -183,9 +182,31 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
OS << ")\n";
OS.indent(4) << " return false;\n";
OS.indent(2) << "}\n";
+ } else if (Predicate->isSubClassOf("SameReg")) {
+ int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
+ int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
+
+ OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
+ << ").getReg().isVirtual()) {\n";
+ OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
+ << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
+ << ").getReg()) {\n";
+
+ OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
+ OS.indent(6) << " return false;\n";
+
+ OS.indent(6) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+ << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+ OS.indent(6)
+ << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+ OS.indent(6) << " if (SecondMI.getOperand(" << FirstOpIdx
+ << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+ OS.indent(6) << " return false;\n";
+ OS.indent(4) << "}\n";
+ OS.indent(2) << "}\n";
} else {
PrintFatalError(Predicate->getLoc(),
- "Unsupported predicate for first instruction: " +
+ "Unsupported predicate for second instruction: " +
Predicate->getType()->getAsString());
}
}
@@ -196,9 +217,8 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
if (Predicate->isSubClassOf("FusionPredicateWithCode"))
OS << Predicate->getValueAsString("Predicate");
else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
- Record *MCPred = Predicate->getValueAsDef("Predicate");
- emitFirstPredicate(MCPred, PE, OS);
- emitSecondPredicate(MCPred, PE, OS);
+ emitFirstPredicate(Predicate, PE, OS);
+ emitSecondPredicate(Predicate, PE, OS);
} else if (Predicate->isSubClassOf("TieReg")) {
int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
@@ -208,8 +228,19 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
<< ").isReg() &&\n";
OS.indent(2) << " FirstMI->getOperand(" << FirstOpIdx
<< ").getReg() == SecondMI.getOperand(" << SecondOpIdx
- << ").getReg()))\n";
- OS.indent(2) << " return false;\n";
+ << ").getReg())) {\n";
+
+ OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
+ OS.indent(4) << " return false;\n";
+
+ OS.indent(4) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+ << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+ OS.indent(4)
+ << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+ OS.indent(4) << " if (FirstMI->getOperand(" << FirstOpIdx
+ << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+ OS.indent(4) << " return false;\n";
+ OS.indent(2) << "}\n";
} else
PrintFatalError(Predicate->getLoc(),
"Unsupported predicate for both instruction: " +
|
74c60af
to
ac50115
Compare
ac50115
to
b426ef9
Compare
I added a |
Wouldn't that be implied by the opcode in the first place? The instruction will already have isCommutable set |
We can't simply use the class SingleFusion<string name, string fieldName, string desc,
Instruction firstInst, Instruction secondInst,
MCInstPredicate firstInstPred = TruePred,
MCInstPredicate secondInstPred = TruePred,
list<FusionPredicate> prolog = [],
list<FusionPredicate> epilog = []>
: SimpleFusion<name, fieldName, desc,
CheckAll<!listconcat(
[CheckOpcode<[firstInst]>],
[firstInstPred])>,
CheckAll<!listconcat(
[CheckOpcode<[secondInst]>],
[secondInstPred])>,
prolog, epilog> {
let IsCommutable = secondInst.isCommutable;
}
def ShiftNAddFusion
: SingleFusion<"shift-n-add-fusion", "HasShiftNAddFusion",
"Enable SLLI+ADD to be fused to shift left by 1/2/3 and add",
SLLI, ADD,
CheckAny<[CheckImmOperand<2, 1>,
CheckImmOperand<2, 2>,
CheckImmOperand<2, 3>]>>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yet another commutable flag is a bit annoying/undiscoverable but OK I guess
If the second instruction is commutable, we should be able to check its commutable operands. A field `IsCommutable` is added to indicate whether we should generate code for checking commutable operands. Fixes llvm#82738
b426ef9
to
d7954a8
Compare
If the second instruction is commutable, we should be able to check
its commutable operands.
A simple RISCV fusion is contained in this PR to show the functionality
is correct, I may remove it when landing.
Fixes #82738