From a6bf7c058af39e249127b4e3a8052b2ec551ea07 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Wed, 26 Mar 2025 17:30:39 +0800 Subject: [PATCH] [mlir][scf] Remove redundant ensureTerminator for `scf.forall` The override function `ensureTerminator` ensures that the terminator `InParallelOp` has a region. However, if the terminator of `scf.forall` is not an `InParallelOp`, calling ensureTerminator causes a crash. Since the InParallelOp builder already guarantees the existence of a region, `ForallOp::ensureTerminator` is redundant and can be safely removed. --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 6 ------ mlir/lib/Dialect/SCF/IR/SCF.cpp | 13 ------------- mlir/test/Dialect/SCF/invalid.mlir | 12 ++++++++++++ 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 6f408b3c924de..b51b61b3d2cb9 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -605,12 +605,6 @@ def ForallOp : SCF_Op<"forall", [ /// Checks if the lbs are zeros and steps are ones. bool isNormalized(); - // The ensureTerminator method generated by SingleBlockImplicitTerminator is - // unaware of the fact that our terminator also needs a region to be - // well-formed. We override it here to ensure that we do the right thing. - static void ensureTerminator(Region & region, OpBuilder & builder, - Location loc); - InParallelOp getTerminator(); // Declare the shared_outs as inits/outs to DestinationStyleOpInterface. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 1cfb866db0b51..0d6a853d5eca7 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1416,19 +1416,6 @@ bool ForallOp::isNormalized() { return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1); } -// The ensureTerminator method generated by SingleBlockImplicitTerminator is -// unaware of the fact that our terminator also needs a region to be -// well-formed. We override it here to ensure that we do the right thing. -void ForallOp::ensureTerminator(Region ®ion, OpBuilder &builder, - Location loc) { - OpTrait::SingleBlockImplicitTerminator::Impl< - ForallOp>::ensureTerminator(region, builder, loc); - auto terminator = - llvm::dyn_cast(region.front().getTerminator()); - if (terminator.getRegion().empty()) - builder.createBlock(&terminator.getRegion()); -} - InParallelOp ForallOp::getTerminator() { return cast(getBody()->getTerminator()); } diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 76c785f3e6166..3d933544b8842 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -672,6 +672,18 @@ func.func @mismatched_mapping(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32> // ----- +func.func @forall_wrong_terminator_op() -> () { + %c100 = arith.constant 100 : index + // expected-error @+2 {{'scf.forall' op expects regions to end with 'scf.forall.in_parallel', found 'llvm.return'}} + // expected-note @below {{in custom textual format, the absence of terminator implies 'scf.forall.in_parallel'}} + scf.forall (%arg0) in (%c100) { + llvm.return + } + return +} + +// ----- + func.func @switch_wrong_case_count(%arg0: index) { // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}} "scf.index_switch"(%arg0) ({