Skip to content

[mlir][scf] Remove redundant ensureTerminator for scf.forall #133081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2025

Conversation

CoTinker
Copy link
Contributor

The override function ensureTerminator ensures that the terminator InParallelOp has a region. However, if the terminator of scf.forall is not an InParallelOp, calling ensureTerminator causes a crash. Since the InParallelOp builder already guarantees the existence of a region, ForallOp::ensureTerminator is redundant and can be safely removed. Fixes #130019.

The override function `ensureTerminator` ensures that the terminator
`InParallelOp` has a region. However, if the terminator of
`scf.forall` is not an `InParallelOp`, calling ensureTerminator causes a
crash. Since the InParallelOp builder already guarantees the existence of a
region, `ForallOp::ensureTerminator` is redundant and can be safely
removed.
@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: Longsheng Mou (CoTinker)

Changes

The override function ensureTerminator ensures that the terminator InParallelOp has a region. However, if the terminator of scf.forall is not an InParallelOp, calling ensureTerminator causes a crash. Since the InParallelOp builder already guarantees the existence of a region, ForallOp::ensureTerminator is redundant and can be safely removed. Fixes #130019.


Full diff: https://github.com/llvm/llvm-project/pull/133081.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (-6)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (-13)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6f408b3c924de..b51b61b3d2cb9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -605,12 +605,6 @@ def ForallOp : SCF_Op<"forall", [
     /// Checks if the lbs are zeros and steps are ones.
     bool isNormalized();
 
-    // The ensureTerminator method generated by SingleBlockImplicitTerminator is
-    // unaware of the fact that our terminator also needs a region to be
-    // well-formed. We override it here to ensure that we do the right thing.
-    static void ensureTerminator(Region & region, OpBuilder & builder,
-                                 Location loc);
-
     InParallelOp getTerminator();
 
     // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1cfb866db0b51..0d6a853d5eca7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1416,19 +1416,6 @@ bool ForallOp::isNormalized() {
   return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
 }
 
-// The ensureTerminator method generated by SingleBlockImplicitTerminator is
-// unaware of the fact that our terminator also needs a region to be
-// well-formed. We override it here to ensure that we do the right thing.
-void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
-                                Location loc) {
-  OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
-      ForallOp>::ensureTerminator(region, builder, loc);
-  auto terminator =
-      llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
-  if (terminator.getRegion().empty())
-    builder.createBlock(&terminator.getRegion());
-}
-
 InParallelOp ForallOp::getTerminator() {
   return cast<InParallelOp>(getBody()->getTerminator());
 }
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 76c785f3e6166..3d933544b8842 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -672,6 +672,18 @@ func.func @mismatched_mapping(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>
 
 // -----
 
+func.func @forall_wrong_terminator_op() -> () {
+  %c100 = arith.constant 100 : index
+  // expected-error @+2 {{'scf.forall' op expects regions to end with 'scf.forall.in_parallel', found 'llvm.return'}}
+  // expected-note @below {{in custom textual format, the absence of terminator implies 'scf.forall.in_parallel'}}
+  scf.forall (%arg0) in (%c100) {
+    llvm.return
+  }
+  return
+}
+
+// -----
+
 func.func @switch_wrong_case_count(%arg0: index) {
   // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
   "scf.index_switch"(%arg0) ({

@CoTinker
Copy link
Contributor Author

The InParallelOp builder:

void InParallelOp::build(OpBuilder &b, OperationState &result) {
OpBuilder::InsertionGuard g(b);
Region *bodyRegion = result.addRegion();
b.createBlock(bodyRegion);
}

@CoTinker CoTinker merged commit ac09b78 into llvm:main Mar 27, 2025
14 checks passed
@CoTinker CoTinker deleted the scf_forall branch March 27, 2025 12:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Parser crash 'scf.forall'
3 participants