Skip to content

Commit 74be3a6

Browse files
committed
Add tests & fix call site
1 parent 0333e16 commit 74be3a6

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,8 @@ applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
170170
bool failed = false;
171171
for (Region &region : op->getRegions()) {
172172
bool regionChanged;
173-
failed |=
174-
applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
175-
.failed();
173+
failed |= applyPatternsGreedily(region, patterns, config, &regionChanged)
174+
.failed();
176175
anyRegionChanged |= regionChanged;
177176
}
178177
if (changed)

mlir/test/Transforms/test-operation-folder.mlir

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
22
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
3+
// RUN: mlir-opt -test-greedy-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE
4+
// RUN: mlir-opt -test-greedy-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD
35

46
func.func @foo() -> i32 {
57
%c42 = arith.constant 42 : i32
@@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) {
2527
}
2628

2729
func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
28-
// Test that we don't reorder existing constants during folding if it isn't necessary.
30+
// Test that we don't reorder existing constants during folding if it isn't
31+
// necessary.
2932
// CHECK: %[[CST:.+]] = arith.constant 1
3033
// CHECK-NEXT: %[[CST:.+]] = arith.constant 2
3134
// CHECK-NEXT: %[[CST:.+]] = arith.constant 3
@@ -34,3 +37,37 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
3437
%2 = arith.constant 3 : i32
3538
return %0, %1, %2 : i32, i32, i32
3639
}
40+
41+
func.func @test_dont_fold() -> (i32, i32, i32, i32, i32, i32) {
42+
// Test either not folding or deduping constants.
43+
44+
// CHECK-LABEL: test_dont_fold
45+
// CHECK-NOT: arith.constant 0
46+
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
47+
// CHECK-DAG: %[[CST:.+]] = arith.constant 1
48+
// CHECK-DAG: %[[CST:.+]] = arith.constant 2
49+
// CHECK-DAG: %[[CST:.+]] = arith.constant 3
50+
// CHECK-NEXT: return
51+
52+
// NOCSE-LABEL: test_dont_fold
53+
// NOCSE-DAG: arith.constant 0 : i32
54+
// NOCSE-DAG: arith.constant 1 : i32
55+
// NOCSE-DAG: arith.constant 2 : i32
56+
// NOCSE-DAG: arith.constant 1 : i32
57+
// NOCSE-DAG: arith.constant 2 : i32
58+
// NOCSE-DAG: arith.constant 3 : i32
59+
// NOCSE-NEXT: return
60+
61+
// NOFOLD-LABEL: test_dont_fold
62+
// NOFOLD: arith.addi
63+
// NOFOLD: arith.addi
64+
// NOFOLD: arith.addi
65+
66+
%c0 = arith.constant 0 : i32
67+
%c1 = arith.constant 1 : i32
68+
%c2 = arith.constant 2 : i32
69+
%0 = arith.addi %c0, %c1 : i32
70+
%1 = arith.addi %0, %c1 : i32
71+
%2 = arith.addi %c2, %c1 : i32
72+
return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32
73+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,9 @@ struct TestGreedyPatternDriver
388388
GreedyRewriteConfig config;
389389
config.useTopDownTraversal = this->useTopDownTraversal;
390390
config.maxIterations = this->maxIterations;
391-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
392-
config);
391+
config.fold = this->fold;
392+
config.cseConstants = this->cseConstants;
393+
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
393394
}
394395

395396
Option<bool> useTopDownTraversal{
@@ -400,6 +401,11 @@ struct TestGreedyPatternDriver
400401
*this, "max-iterations",
401402
llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402403
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
404+
Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405+
llvm::cl::init(GreedyRewriteConfig().fold)};
406+
Option<bool> cseConstants{*this, "cse-constants",
407+
llvm::cl::desc("Whether to CSE constants"),
408+
llvm::cl::init(GreedyRewriteConfig().cseConstants)};
403409
};
404410

405411
struct DumpNotifications : public RewriterBase::Listener {

0 commit comments

Comments
 (0)