Skip to content

Commit eec4234

Browse files
committed
[CIR][IR] Implement cir.continue operation
Detaches the representation of the C/C++ `continue` statement into a separate operation. This simplifies mostly lowering and verifications related to `continue` statements, as well as the definition and lowering of the `cir.yield` operation. A few checks regarding region terminators were also removed from the lowering stage, since they are already enforced by MLIR. ghstack-source-id: 1810a48 Pull Request resolved: #394
1 parent 0e4bbf1 commit eec4234

File tree

9 files changed

+89
-61
lines changed

9 files changed

+89
-61
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -607,13 +607,12 @@ def ConditionOp : CIR_Op<"condition", [
607607

608608
def YieldOpKind_BK : I32EnumAttrCase<"Break", 1, "break">;
609609
def YieldOpKind_FT : I32EnumAttrCase<"Fallthrough", 2, "fallthrough">;
610-
def YieldOpKind_CE : I32EnumAttrCase<"Continue", 3, "continue">;
611610
def YieldOpKind_NS : I32EnumAttrCase<"NoSuspend", 4, "nosuspend">;
612611

613612
def YieldOpKind : I32EnumAttr<
614613
"YieldOpKind",
615614
"yield kind",
616-
[YieldOpKind_BK, YieldOpKind_FT, YieldOpKind_CE, YieldOpKind_NS]> {
615+
[YieldOpKind_BK, YieldOpKind_FT, YieldOpKind_NS]> {
617616
let cppNamespace = "::mlir::cir";
618617
}
619618

@@ -634,8 +633,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
634633
cannot be used if not dominated by these parent operations.
635634
- `fallthrough`: execution falls to the next region in `cir.switch` case list.
636635
Only available inside `cir.switch` regions.
637-
- `continue`: only allowed under `cir.loop`, continue execution to the next
638-
loop step.
639636
- `nosuspend`: specific to the `ready` region inside `cir.await` op, it makes
640637
control-flow to be transfered back to the parent, preventing suspension.
641638

@@ -657,11 +654,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
657654
}, ...
658655
]
659656

660-
cir.loop (cond : {...}, step : {...}) {
661-
...
662-
cir.yield continue
663-
}
664-
665657
cir.await(init, ready : {
666658
// Call std::suspend_always::await_ready
667659
%18 = cir.call @_ZNSt14suspend_always11await_readyEv(...)
@@ -718,9 +710,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
718710
bool isBreak() {
719711
return !isPlain() && *getKind() == YieldOpKind::Break;
720712
}
721-
bool isContinue() {
722-
return !isPlain() && *getKind() == YieldOpKind::Continue;
723-
}
724713
bool isNoSuspend() {
725714
return !isPlain() && *getKind() == YieldOpKind::NoSuspend;
726715
}
@@ -729,6 +718,20 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
729718
let hasVerifier = 1;
730719
}
731720

721+
//===----------------------------------------------------------------------===//
722+
// ContinueOp
723+
//===----------------------------------------------------------------------===//
724+
725+
def ContinueOp : CIR_Op<"continue", [Terminator]> {
726+
let summary = "C/C++ `continue` statement equivalent";
727+
let description = [{
728+
The `cir.continue` operation is used to continue execution to the next
729+
iteration of a loop. It is only allowed within `cir.loop` regions.
730+
}];
731+
let assemblyFormat = "attr-dict";
732+
let hasVerifier = 1;
733+
}
734+
732735
//===----------------------------------------------------------------------===//
733736
// ScopeOp
734737
//===----------------------------------------------------------------------===//
@@ -1166,7 +1169,7 @@ def LoopOp : CIR_Op<"loop",
11661169
`cir.loop` represents C/C++ loop forms. It defines 3 blocks:
11671170
- `cond`: region can contain multiple blocks, terminated by regular
11681171
`cir.yield` when control should yield back to the parent, and
1169-
`cir.yield continue` when execution continues to another region.
1172+
`cir.continue` when execution continues to the next region.
11701173
The region destination depends on the loop form specified.
11711174
- `step`: region with one block, containing code to compute the
11721175
loop step, must be terminated with `cir.yield`.
@@ -1181,7 +1184,7 @@ def LoopOp : CIR_Op<"loop",
11811184
// i = i + 1;
11821185
// }
11831186
cir.loop while(cond : {
1184-
cir.yield continue
1187+
cir.continue
11851188
}, step : {
11861189
cir.yield
11871190
}) {

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
583583
return create<mlir::cir::ConditionOp>(condition.getLoc(), condition);
584584
}
585585

586+
/// Create a continue operation.
587+
mlir::cir::ContinueOp createContinue(mlir::Location loc) {
588+
return create<mlir::cir::ContinueOp>(loc);
589+
}
590+
586591
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
587592
mlir::Value src, mlir::Value len) {
588593
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,6 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,
299299
return buildGotoStmt(cast<GotoStmt>(*S));
300300
case Stmt::ContinueStmtClass:
301301
return buildContinueStmt(cast<ContinueStmt>(*S));
302-
303302
case Stmt::NullStmtClass:
304303
break;
305304

@@ -570,11 +569,7 @@ mlir::LogicalResult CIRGenFunction::buildLabel(const LabelDecl *D) {
570569

571570
mlir::LogicalResult
572571
CIRGenFunction::buildContinueStmt(const clang::ContinueStmt &S) {
573-
builder.create<YieldOp>(
574-
getLoc(S.getContinueLoc()),
575-
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
576-
mlir::cir::YieldOpKind::Continue),
577-
mlir::ValueRange({}));
572+
builder.createContinue(getLoc(S.getContinueLoc()));
578573
return mlir::success();
579574
}
580575

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,16 @@ static void printConstantValue(OpAsmPrinter &p, cir::ConstantOp op,
336336

337337
OpFoldResult ConstantOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
338338

339+
//===----------------------------------------------------------------------===//
340+
// ContinueOp
341+
//===----------------------------------------------------------------------===//
342+
343+
LogicalResult ContinueOp::verify() {
344+
if (!this->getOperation()->getParentOfType<LoopOp>())
345+
return emitOpError("must be within a loop");
346+
return success();
347+
}
348+
339349
//===----------------------------------------------------------------------===//
340350
// CastOp
341351
//===----------------------------------------------------------------------===//
@@ -797,15 +807,6 @@ mlir::LogicalResult YieldOp::verify() {
797807
return false;
798808
};
799809

800-
auto isDominatedByLoop = [](Operation *parentOp) {
801-
while (!llvm::isa<cir::FuncOp>(parentOp)) {
802-
if (llvm::isa<cir::LoopOp>(parentOp))
803-
return true;
804-
parentOp = parentOp->getParentOp();
805-
}
806-
return false;
807-
};
808-
809810
if (isNoSuspend()) {
810811
if (!isDominatedByProperAwaitRegion(getOperation()->getParentOp(),
811812
getOperation()->getParentRegion()))
@@ -819,12 +820,6 @@ mlir::LogicalResult YieldOp::verify() {
819820
return mlir::success();
820821
}
821822

822-
if (isContinue()) {
823-
if (!isDominatedByLoop(getOperation()->getParentOp()))
824-
return emitOpError() << "shall be dominated by 'cir.loop'";
825-
return mlir::success();
826-
}
827-
828823
if (isFallthrough()) {
829824
if (!llvm::isa<SwitchOp>(getOperation()->getParentOp()))
830825
return emitOpError() << "fallthrough only expected within 'cir.switch'";

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
#include "mlir/IR/BuiltinDialect.h"
3030
#include "mlir/IR/BuiltinOps.h"
3131
#include "mlir/IR/BuiltinTypes.h"
32+
#include "mlir/IR/OpDefinition.h"
3233
#include "mlir/IR/Operation.h"
3334
#include "mlir/IR/Types.h"
3435
#include "mlir/IR/Value.h"
3536
#include "mlir/IR/ValueRange.h"
37+
#include "mlir/IR/Visitors.h"
3638
#include "mlir/Interfaces/DataLayoutInterfaces.h"
3739
#include "mlir/Pass/Pass.h"
3840
#include "mlir/Pass/PassManager.h"
@@ -65,6 +67,36 @@ using namespace llvm;
6567
namespace cir {
6668
namespace direct {
6769

70+
//===----------------------------------------------------------------------===//
71+
// Helper Methods
72+
//===----------------------------------------------------------------------===//
73+
74+
namespace {
75+
76+
/// Lowers operations with the terminator trait that have a single successor.
77+
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
78+
mlir::ConversionPatternRewriter &rewriter) {
79+
assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
80+
mlir::OpBuilder::InsertionGuard guard(rewriter);
81+
rewriter.setInsertionPoint(op);
82+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, dest);
83+
}
84+
85+
/// Walks a region while skipping operations of type `Ops`. This ensures the
86+
/// callback is not applied to said operations and its children.
87+
template <typename... Ops>
88+
void walkRegionSkipping(mlir::Region &region,
89+
mlir::function_ref<void(mlir::Operation *)> callback) {
90+
region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
91+
if (isa<Ops...>(op))
92+
return mlir::WalkResult::skip();
93+
callback(op);
94+
return mlir::WalkResult::advance();
95+
});
96+
}
97+
98+
} // namespace
99+
68100
//===----------------------------------------------------------------------===//
69101
// Visitors for Lowering CIR Const Attributes
70102
//===----------------------------------------------------------------------===//
@@ -441,8 +473,15 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
441473

442474
lowerNestedYield(mlir::cir::YieldOpKind::Break, rewriter, bodyRegion,
443475
continueBlock);
444-
lowerNestedYield(mlir::cir::YieldOpKind::Continue, rewriter, bodyRegion,
445-
&stepBlock);
476+
477+
// Lower continue statements.
478+
mlir::Block &dest =
479+
(kind != LoopKind::For ? condFrontBlock : stepFrontBlock);
480+
walkRegionSkipping<mlir::cir::LoopOp>(
481+
loopOp.getBody(), [&](mlir::Operation *op) {
482+
if (isa<mlir::cir::ContinueOp>(op))
483+
lowerTerminator(op, &dest, rewriter);
484+
});
446485

447486
// Move loop op region contents to current CFG.
448487
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -672,9 +711,8 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
672711
}
673712
};
674713

675-
static bool isBreakOrContinue(mlir::cir::YieldOp &op) {
676-
return op.getKind() == mlir::cir::YieldOpKind::Break ||
677-
op.getKind() == mlir::cir::YieldOpKind::Continue;
714+
static bool isBreak(mlir::cir::YieldOp &op) {
715+
return op.getKind() == mlir::cir::YieldOpKind::Break;
678716
}
679717

680718
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
@@ -705,12 +743,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
705743
rewriter.setInsertionPointToEnd(thenAfterBody);
706744
if (auto thenYieldOp =
707745
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
708-
if (!isBreakOrContinue(thenYieldOp)) // lowering of parent loop yields is
709-
// deferred to loop lowering
746+
if (!isBreak(thenYieldOp)) // lowering of parent loop yields is
747+
// deferred to loop lowering
710748
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
711749
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
712-
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
713-
llvm_unreachable("what are we terminating with?");
714750
}
715751

716752
rewriter.setInsertionPointToEnd(continueBlock);
@@ -736,13 +772,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
736772
rewriter.setInsertionPointToEnd(elseAfterBody);
737773
if (auto elseYieldOp =
738774
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
739-
if (!isBreakOrContinue(elseYieldOp)) // lowering of parent loop yields
740-
// is deferred to loop lowering
775+
if (!isBreak(elseYieldOp)) // lowering of parent loop yields
776+
// is deferred to loop lowering
741777
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
742778
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
743-
} else if (!dyn_cast<mlir::cir::ReturnOp>(
744-
elseAfterBody->getTerminator())) {
745-
llvm_unreachable("what are we terminating with?");
746779
}
747780
}
748781

@@ -798,7 +831,7 @@ class CIRScopeOpLowering
798831
rewriter.setInsertionPointToEnd(afterBody);
799832
auto yieldOp = dyn_cast<mlir::cir::YieldOp>(afterBody->getTerminator());
800833

801-
if (yieldOp && !isBreakOrContinue(yieldOp)) {
834+
if (yieldOp && !isBreak(yieldOp)) {
802835
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
803836
yieldOp, yieldOp.getArgs(), continueBlock);
804837

@@ -1400,9 +1433,6 @@ class CIRSwitchOpLowering
14001433
case mlir::cir::YieldOpKind::Break:
14011434
rewriteYieldOp(rewriter, yieldOp, exitBlock);
14021435
break;
1403-
case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1404-
// loop lowering
1405-
break;
14061436
default:
14071437
return op->emitError("invalid yield kind in case statement");
14081438
}

clang/test/CIR/CodeGen/loop.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ void l4() {
180180
// CHECK-NEXT: %11 = cir.const(#cir.int<10> : !s32i) : !s32i
181181
// CHECK-NEXT: %12 = cir.cmp(lt, %10, %11) : !s32i, !cir.bool
182182
// CHECK-NEXT: cir.if %12 {
183-
// CHECK-NEXT: cir.yield continue
183+
// CHECK-NEXT: cir.continue
184184
// CHECK-NEXT: }
185185
// CHECK-NEXT: }
186186

clang/test/CIR/IR/invalid.cir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ cir.func @yieldbreak() {
7979
cir.func @yieldcontinue() {
8080
%0 = cir.const(#true) : !cir.bool
8181
cir.if %0 {
82-
cir.yield continue // expected-error {{shall be dominated by 'cir.loop'}}
82+
cir.continue // expected-error {{op must be within a loop}}
8383
}
8484
cir.return
8585
}

clang/test/CIR/IR/loop.cir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ cir.func @l0() {
5252
cir.store %6, %0 : !u32i, cir.ptr <!u32i>
5353
%7 = cir.const(#true) : !cir.bool
5454
cir.if %7 {
55-
cir.yield continue
55+
cir.continue
5656
}
5757
cir.yield
5858
}
@@ -118,7 +118,7 @@ cir.func @l0() {
118118
// CHECK-NEXT: cir.store %6, %0 : !u32i, cir.ptr <!u32i>
119119
// CHECK-NEXT: %7 = cir.const(#true) : !cir.bool
120120
// CHECK-NEXT: cir.if %7 {
121-
// CHECK-NEXT: cir.yield continue
121+
// CHECK-NEXT: cir.continue
122122
// CHECK-NEXT: }
123123
// CHECK-NEXT: cir.yield
124124
// CHECK-NEXT: }

clang/test/CIR/Lowering/loops-with-continue.cir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ module {
2727
%4 = cir.cmp(eq, %2, %3) : !s32i, !s32i
2828
%5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool
2929
cir.if %5 {
30-
cir.yield continue
30+
cir.continue
3131
}
3232
}
3333
}
@@ -107,7 +107,7 @@ module {
107107
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
108108
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
109109
cir.if %7 {
110-
cir.yield continue
110+
cir.continue
111111
}
112112
}
113113
}
@@ -189,7 +189,7 @@ cir.func @testWhile() {
189189
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
190190
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
191191
cir.if %7 {
192-
cir.yield continue
192+
cir.continue
193193
}
194194
}
195195
cir.yield
@@ -243,7 +243,7 @@ cir.func @testWhile() {
243243
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
244244
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
245245
cir.if %7 {
246-
cir.yield continue
246+
cir.continue
247247
}
248248
}
249249
cir.yield

0 commit comments

Comments
 (0)