29
29
#include " mlir/IR/BuiltinDialect.h"
30
30
#include " mlir/IR/BuiltinOps.h"
31
31
#include " mlir/IR/BuiltinTypes.h"
32
+ #include " mlir/IR/OpDefinition.h"
32
33
#include " mlir/IR/Operation.h"
33
34
#include " mlir/IR/Types.h"
34
35
#include " mlir/IR/Value.h"
35
36
#include " mlir/IR/ValueRange.h"
37
+ #include " mlir/IR/Visitors.h"
36
38
#include " mlir/Interfaces/DataLayoutInterfaces.h"
37
39
#include " mlir/Pass/Pass.h"
38
40
#include " mlir/Pass/PassManager.h"
@@ -65,6 +67,36 @@ using namespace llvm;
65
67
namespace cir {
66
68
namespace direct {
67
69
70
+ // ===----------------------------------------------------------------------===//
71
+ // Helper Methods
72
+ // ===----------------------------------------------------------------------===//
73
+
74
+ namespace {
75
+
76
+ // / Lowers operations with the terminator trait that have a single successor.
77
+ void lowerTerminator (mlir::Operation *op, mlir::Block *dest,
78
+ mlir::ConversionPatternRewriter &rewriter) {
79
+ assert (op->hasTrait <mlir::OpTrait::IsTerminator>() && " not a terminator" );
80
+ mlir::OpBuilder::InsertionGuard guard (rewriter);
81
+ rewriter.setInsertionPoint (op);
82
+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(op, dest);
83
+ }
84
+
85
+ // / Walks a region while skipping operations of type `Ops`. This ensures the
86
+ // / callback is not applied to said operations and its children.
87
+ template <typename ... Ops>
88
+ void walkRegionSkipping (mlir::Region ®ion,
89
+ mlir::function_ref<void (mlir::Operation *)> callback) {
90
+ region.walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
91
+ if (isa<Ops...>(op))
92
+ return mlir::WalkResult::skip ();
93
+ callback (op);
94
+ return mlir::WalkResult::advance ();
95
+ });
96
+ }
97
+
98
+ } // namespace
99
+
68
100
// ===----------------------------------------------------------------------===//
69
101
// Visitors for Lowering CIR Const Attributes
70
102
// ===----------------------------------------------------------------------===//
@@ -441,8 +473,15 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
441
473
442
474
lowerNestedYield (mlir::cir::YieldOpKind::Break, rewriter, bodyRegion,
443
475
continueBlock);
444
- lowerNestedYield (mlir::cir::YieldOpKind::Continue, rewriter, bodyRegion,
445
- &stepBlock);
476
+
477
+ // Lower continue statements.
478
+ mlir::Block &dest =
479
+ (kind != LoopKind::For ? condFrontBlock : stepFrontBlock);
480
+ walkRegionSkipping<mlir::cir::LoopOp>(
481
+ loopOp.getBody (), [&](mlir::Operation *op) {
482
+ if (isa<mlir::cir::ContinueOp>(op))
483
+ lowerTerminator (op, &dest, rewriter);
484
+ });
446
485
447
486
// Move loop op region contents to current CFG.
448
487
rewriter.inlineRegionBefore (condRegion, continueBlock);
@@ -672,9 +711,8 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
672
711
}
673
712
};
674
713
675
- static bool isBreakOrContinue (mlir::cir::YieldOp &op) {
676
- return op.getKind () == mlir::cir::YieldOpKind::Break ||
677
- op.getKind () == mlir::cir::YieldOpKind::Continue;
714
+ static bool isBreak (mlir::cir::YieldOp &op) {
715
+ return op.getKind () == mlir::cir::YieldOpKind::Break;
678
716
}
679
717
680
718
class CIRIfLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
@@ -705,12 +743,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
705
743
rewriter.setInsertionPointToEnd (thenAfterBody);
706
744
if (auto thenYieldOp =
707
745
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator ())) {
708
- if (!isBreakOrContinue (thenYieldOp)) // lowering of parent loop yields is
709
- // deferred to loop lowering
746
+ if (!isBreak (thenYieldOp)) // lowering of parent loop yields is
747
+ // deferred to loop lowering
710
748
rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
711
749
thenYieldOp, thenYieldOp.getArgs (), continueBlock);
712
- } else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator ())) {
713
- llvm_unreachable (" what are we terminating with?" );
714
750
}
715
751
716
752
rewriter.setInsertionPointToEnd (continueBlock);
@@ -736,13 +772,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
736
772
rewriter.setInsertionPointToEnd (elseAfterBody);
737
773
if (auto elseYieldOp =
738
774
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator ())) {
739
- if (!isBreakOrContinue (elseYieldOp)) // lowering of parent loop yields
740
- // is deferred to loop lowering
775
+ if (!isBreak (elseYieldOp)) // lowering of parent loop yields
776
+ // is deferred to loop lowering
741
777
rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
742
778
elseYieldOp, elseYieldOp.getArgs (), continueBlock);
743
- } else if (!dyn_cast<mlir::cir::ReturnOp>(
744
- elseAfterBody->getTerminator ())) {
745
- llvm_unreachable (" what are we terminating with?" );
746
779
}
747
780
}
748
781
@@ -798,7 +831,7 @@ class CIRScopeOpLowering
798
831
rewriter.setInsertionPointToEnd (afterBody);
799
832
auto yieldOp = dyn_cast<mlir::cir::YieldOp>(afterBody->getTerminator ());
800
833
801
- if (yieldOp && !isBreakOrContinue (yieldOp)) {
834
+ if (yieldOp && !isBreak (yieldOp)) {
802
835
auto branchOp = rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
803
836
yieldOp, yieldOp.getArgs (), continueBlock);
804
837
@@ -1400,9 +1433,6 @@ class CIRSwitchOpLowering
1400
1433
case mlir::cir::YieldOpKind::Break:
1401
1434
rewriteYieldOp (rewriter, yieldOp, exitBlock);
1402
1435
break ;
1403
- case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1404
- // loop lowering
1405
- break ;
1406
1436
default :
1407
1437
return op->emitError (" invalid yield kind in case statement" );
1408
1438
}
0 commit comments