Skip to content

Commit 554e84a

Browse files
sitio-coutolanza
authored andcommitted
[CIR][IR] Implement cir.continue operation
Detaches the representation of the C/C++ `continue` statement into a separate operation. This simplifies mostly lowering and verifications related to `continue` statements, as well as the definition and lowering of the `cir.yield` operation. A few checks regarding region terminators were also removed from the lowering stage, since they are already enforced by MLIR. ghstack-source-id: 1810a48ada88fe7ef5638b0758a2298d9cfbdb8b Pull Request resolved: llvm/clangir#394
1 parent 0f41c1a commit 554e84a

File tree

9 files changed

+89
-60
lines changed

9 files changed

+89
-60
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -607,13 +607,12 @@ def ConditionOp : CIR_Op<"condition", [
607607

608608
def YieldOpKind_BK : I32EnumAttrCase<"Break", 1, "break">;
609609
def YieldOpKind_FT : I32EnumAttrCase<"Fallthrough", 2, "fallthrough">;
610-
def YieldOpKind_CE : I32EnumAttrCase<"Continue", 3, "continue">;
611610
def YieldOpKind_NS : I32EnumAttrCase<"NoSuspend", 4, "nosuspend">;
612611

613612
def YieldOpKind : I32EnumAttr<
614613
"YieldOpKind",
615614
"yield kind",
616-
[YieldOpKind_BK, YieldOpKind_FT, YieldOpKind_CE, YieldOpKind_NS]> {
615+
[YieldOpKind_BK, YieldOpKind_FT, YieldOpKind_NS]> {
617616
let cppNamespace = "::mlir::cir";
618617
}
619618

@@ -634,8 +633,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
634633
cannot be used if not dominated by these parent operations.
635634
- `fallthrough`: execution falls to the next region in `cir.switch` case list.
636635
Only available inside `cir.switch` regions.
637-
- `continue`: only allowed under `cir.loop`, continue execution to the next
638-
loop step.
639636
- `nosuspend`: specific to the `ready` region inside `cir.await` op, it makes
640637
control-flow to be transfered back to the parent, preventing suspension.
641638

@@ -657,11 +654,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
657654
}, ...
658655
]
659656

660-
cir.loop (cond : {...}, step : {...}) {
661-
...
662-
cir.yield continue
663-
}
664-
665657
cir.await(init, ready : {
666658
// Call std::suspend_always::await_ready
667659
%18 = cir.call @_ZNSt14suspend_always11await_readyEv(...)
@@ -718,9 +710,6 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
718710
bool isBreak() {
719711
return !isPlain() && *getKind() == YieldOpKind::Break;
720712
}
721-
bool isContinue() {
722-
return !isPlain() && *getKind() == YieldOpKind::Continue;
723-
}
724713
bool isNoSuspend() {
725714
return !isPlain() && *getKind() == YieldOpKind::NoSuspend;
726715
}
@@ -729,6 +718,20 @@ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
729718
let hasVerifier = 1;
730719
}
731720

721+
//===----------------------------------------------------------------------===//
722+
// ContinueOp
723+
//===----------------------------------------------------------------------===//
724+
725+
def ContinueOp : CIR_Op<"continue", [Terminator]> {
726+
let summary = "C/C++ `continue` statement equivalent";
727+
let description = [{
728+
The `cir.continue` operation is used to continue execution to the next
729+
iteration of a loop. It is only allowed within `cir.loop` regions.
730+
}];
731+
let assemblyFormat = "attr-dict";
732+
let hasVerifier = 1;
733+
}
734+
732735
//===----------------------------------------------------------------------===//
733736
// ScopeOp
734737
//===----------------------------------------------------------------------===//
@@ -1166,7 +1169,7 @@ def LoopOp : CIR_Op<"loop",
11661169
`cir.loop` represents C/C++ loop forms. It defines 3 blocks:
11671170
- `cond`: region can contain multiple blocks, terminated by regular
11681171
`cir.yield` when control should yield back to the parent, and
1169-
`cir.yield continue` when execution continues to another region.
1172+
`cir.continue` when execution continues to the next region.
11701173
The region destination depends on the loop form specified.
11711174
- `step`: region with one block, containing code to compute the
11721175
loop step, must be terminated with `cir.yield`.
@@ -1181,7 +1184,7 @@ def LoopOp : CIR_Op<"loop",
11811184
// i = i + 1;
11821185
// }
11831186
cir.loop while(cond : {
1184-
cir.yield continue
1187+
cir.continue
11851188
}, step : {
11861189
cir.yield
11871190
}) {

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
585585
return create<mlir::cir::ConditionOp>(condition.getLoc(), condition);
586586
}
587587

588+
/// Create a continue operation.
589+
mlir::cir::ContinueOp createContinue(mlir::Location loc) {
590+
return create<mlir::cir::ContinueOp>(loc);
591+
}
592+
588593
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
589594
mlir::Value src, mlir::Value len) {
590595
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,
301301
return buildGotoStmt(cast<GotoStmt>(*S));
302302
case Stmt::ContinueStmtClass:
303303
return buildContinueStmt(cast<ContinueStmt>(*S));
304-
305304
case Stmt::NullStmtClass:
306305
break;
307306

@@ -572,11 +571,7 @@ mlir::LogicalResult CIRGenFunction::buildLabel(const LabelDecl *D) {
572571

573572
mlir::LogicalResult
574573
CIRGenFunction::buildContinueStmt(const clang::ContinueStmt &S) {
575-
builder.create<YieldOp>(
576-
getLoc(S.getContinueLoc()),
577-
mlir::cir::YieldOpKindAttr::get(builder.getContext(),
578-
mlir::cir::YieldOpKind::Continue),
579-
mlir::ValueRange({}));
574+
builder.createContinue(getLoc(S.getContinueLoc()));
580575
return mlir::success();
581576
}
582577

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,16 @@ static void printConstantValue(OpAsmPrinter &p, cir::ConstantOp op,
341341

342342
OpFoldResult ConstantOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
343343

344+
//===----------------------------------------------------------------------===//
345+
// ContinueOp
346+
//===----------------------------------------------------------------------===//
347+
348+
LogicalResult ContinueOp::verify() {
349+
if (!this->getOperation()->getParentOfType<LoopOp>())
350+
return emitOpError("must be within a loop");
351+
return success();
352+
}
353+
344354
//===----------------------------------------------------------------------===//
345355
// CastOp
346356
//===----------------------------------------------------------------------===//
@@ -802,15 +812,6 @@ mlir::LogicalResult YieldOp::verify() {
802812
return false;
803813
};
804814

805-
auto isDominatedByLoop = [](Operation *parentOp) {
806-
while (!llvm::isa<cir::FuncOp>(parentOp)) {
807-
if (llvm::isa<cir::LoopOp>(parentOp))
808-
return true;
809-
parentOp = parentOp->getParentOp();
810-
}
811-
return false;
812-
};
813-
814815
if (isNoSuspend()) {
815816
if (!isDominatedByProperAwaitRegion(getOperation()->getParentOp(),
816817
getOperation()->getParentRegion()))
@@ -824,12 +825,6 @@ mlir::LogicalResult YieldOp::verify() {
824825
return mlir::success();
825826
}
826827

827-
if (isContinue()) {
828-
if (!isDominatedByLoop(getOperation()->getParentOp()))
829-
return emitOpError() << "shall be dominated by 'cir.loop'";
830-
return mlir::success();
831-
}
832-
833828
if (isFallthrough()) {
834829
if (!llvm::isa<SwitchOp>(getOperation()->getParentOp()))
835830
return emitOpError() << "fallthrough only expected within 'cir.switch'";

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
#include "mlir/IR/BuiltinDialect.h"
3030
#include "mlir/IR/BuiltinOps.h"
3131
#include "mlir/IR/BuiltinTypes.h"
32+
#include "mlir/IR/OpDefinition.h"
3233
#include "mlir/IR/Operation.h"
3334
#include "mlir/IR/Types.h"
3435
#include "mlir/IR/Value.h"
3536
#include "mlir/IR/ValueRange.h"
37+
#include "mlir/IR/Visitors.h"
3638
#include "mlir/Interfaces/DataLayoutInterfaces.h"
3739
#include "mlir/Pass/Pass.h"
3840
#include "mlir/Pass/PassManager.h"
@@ -65,6 +67,36 @@ using namespace llvm;
6567
namespace cir {
6668
namespace direct {
6769

70+
//===----------------------------------------------------------------------===//
71+
// Helper Methods
72+
//===----------------------------------------------------------------------===//
73+
74+
namespace {
75+
76+
/// Lowers operations with the terminator trait that have a single successor.
77+
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
78+
mlir::ConversionPatternRewriter &rewriter) {
79+
assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
80+
mlir::OpBuilder::InsertionGuard guard(rewriter);
81+
rewriter.setInsertionPoint(op);
82+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, dest);
83+
}
84+
85+
/// Walks a region while skipping operations of type `Ops`. This ensures the
86+
/// callback is not applied to said operations and its children.
87+
template <typename... Ops>
88+
void walkRegionSkipping(mlir::Region &region,
89+
mlir::function_ref<void(mlir::Operation *)> callback) {
90+
region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
91+
if (isa<Ops...>(op))
92+
return mlir::WalkResult::skip();
93+
callback(op);
94+
return mlir::WalkResult::advance();
95+
});
96+
}
97+
98+
} // namespace
99+
68100
//===----------------------------------------------------------------------===//
69101
// Visitors for Lowering CIR Const Attributes
70102
//===----------------------------------------------------------------------===//
@@ -447,8 +479,15 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
447479

448480
lowerNestedYield(mlir::cir::YieldOpKind::Break, rewriter, bodyRegion,
449481
continueBlock);
450-
lowerNestedYield(mlir::cir::YieldOpKind::Continue, rewriter, bodyRegion,
451-
&stepBlock);
482+
483+
// Lower continue statements.
484+
mlir::Block &dest =
485+
(kind != LoopKind::For ? condFrontBlock : stepFrontBlock);
486+
walkRegionSkipping<mlir::cir::LoopOp>(
487+
loopOp.getBody(), [&](mlir::Operation *op) {
488+
if (isa<mlir::cir::ContinueOp>(op))
489+
lowerTerminator(op, &dest, rewriter);
490+
});
452491

453492
// Move loop op region contents to current CFG.
454493
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -678,9 +717,8 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
678717
}
679718
};
680719

681-
static bool isBreakOrContinue(mlir::cir::YieldOp &op) {
682-
return op.getKind() == mlir::cir::YieldOpKind::Break ||
683-
op.getKind() == mlir::cir::YieldOpKind::Continue;
720+
static bool isBreak(mlir::cir::YieldOp &op) {
721+
return op.getKind() == mlir::cir::YieldOpKind::Break;
684722
}
685723

686724
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
@@ -711,12 +749,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
711749
rewriter.setInsertionPointToEnd(thenAfterBody);
712750
if (auto thenYieldOp =
713751
mlir::dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
714-
if (!isBreakOrContinue(thenYieldOp)) // lowering of parent loop yields is
715-
// deferred to loop lowering
752+
if (!isBreak(thenYieldOp)) // lowering of parent loop yields is
753+
// deferred to loop lowering
716754
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
717755
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
718-
} else if (!mlir::dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
719-
llvm_unreachable("what are we terminating with?");
720756
}
721757

722758
rewriter.setInsertionPointToEnd(continueBlock);
@@ -742,12 +778,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
742778
rewriter.setInsertionPointToEnd(elseAfterBody);
743779
if (auto elseYieldOp =
744780
mlir::dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
745-
if (!isBreakOrContinue(elseYieldOp)) // lowering of parent loop yields
746-
// is deferred to loop lowering
781+
if (!isBreak(elseYieldOp)) // lowering of parent loop yields
782+
// is deferred to loop lowering
747783
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
748784
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
749-
} else if (!mlir::dyn_cast<mlir::cir::ReturnOp>(elseAfterBody->getTerminator())) {
750-
llvm_unreachable("what are we terminating with?");
751785
}
752786
}
753787

@@ -803,7 +837,7 @@ class CIRScopeOpLowering
803837
rewriter.setInsertionPointToEnd(afterBody);
804838
auto yieldOp = mlir::dyn_cast<mlir::cir::YieldOp>(afterBody->getTerminator());
805839

806-
if (yieldOp && !isBreakOrContinue(yieldOp)) {
840+
if (yieldOp && !isBreak(yieldOp)) {
807841
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
808842
yieldOp, yieldOp.getArgs(), continueBlock);
809843

@@ -1404,9 +1438,6 @@ class CIRSwitchOpLowering
14041438
case mlir::cir::YieldOpKind::Break:
14051439
rewriteYieldOp(rewriter, yieldOp, exitBlock);
14061440
break;
1407-
case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1408-
// loop lowering
1409-
break;
14101441
default:
14111442
return op->emitError("invalid yield kind in case statement");
14121443
}

clang/test/CIR/CodeGen/loop.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ void l4() {
180180
// CHECK-NEXT: %11 = cir.const(#cir.int<10> : !s32i) : !s32i
181181
// CHECK-NEXT: %12 = cir.cmp(lt, %10, %11) : !s32i, !cir.bool
182182
// CHECK-NEXT: cir.if %12 {
183-
// CHECK-NEXT: cir.yield continue
183+
// CHECK-NEXT: cir.continue
184184
// CHECK-NEXT: }
185185
// CHECK-NEXT: }
186186

clang/test/CIR/IR/invalid.cir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ cir.func @yieldbreak() {
8080
cir.func @yieldcontinue() {
8181
%0 = cir.const(#true) : !cir.bool
8282
cir.if %0 {
83-
cir.yield continue // expected-error {{shall be dominated by 'cir.loop'}}
83+
cir.continue // expected-error {{op must be within a loop}}
8484
}
8585
cir.return
8686
}

clang/test/CIR/IR/loop.cir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ cir.func @l0() {
5252
cir.store %6, %0 : !u32i, cir.ptr <!u32i>
5353
%7 = cir.const(#true) : !cir.bool
5454
cir.if %7 {
55-
cir.yield continue
55+
cir.continue
5656
}
5757
cir.yield
5858
}
@@ -118,7 +118,7 @@ cir.func @l0() {
118118
// CHECK-NEXT: cir.store %6, %0 : !u32i, cir.ptr <!u32i>
119119
// CHECK-NEXT: %7 = cir.const(#true) : !cir.bool
120120
// CHECK-NEXT: cir.if %7 {
121-
// CHECK-NEXT: cir.yield continue
121+
// CHECK-NEXT: cir.continue
122122
// CHECK-NEXT: }
123123
// CHECK-NEXT: cir.yield
124124
// CHECK-NEXT: }

clang/test/CIR/Lowering/loops-with-continue.cir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ module {
2727
%4 = cir.cmp(eq, %2, %3) : !s32i, !s32i
2828
%5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool
2929
cir.if %5 {
30-
cir.yield continue
30+
cir.continue
3131
}
3232
}
3333
}
@@ -107,7 +107,7 @@ module {
107107
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
108108
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
109109
cir.if %7 {
110-
cir.yield continue
110+
cir.continue
111111
}
112112
}
113113
}
@@ -189,7 +189,7 @@ cir.func @testWhile() {
189189
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
190190
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
191191
cir.if %7 {
192-
cir.yield continue
192+
cir.continue
193193
}
194194
}
195195
cir.yield
@@ -243,7 +243,7 @@ cir.func @testWhile() {
243243
%6 = cir.cmp(eq, %4, %5) : !s32i, !s32i
244244
%7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
245245
cir.if %7 {
246-
cir.yield continue
246+
cir.continue
247247
}
248248
}
249249
cir.yield

0 commit comments

Comments
 (0)