Skip to content

Commit ac09b78

Browse files
authored
[mlir][scf] Remove redundant ensureTerminator for scf.forall (#133081)
The override function `ensureTerminator` ensures that the terminator `InParallelOp` has a region. However, if the terminator of `scf.forall` is not an `InParallelOp`, calling ensureTerminator causes a crash. Since the InParallelOp builder already guarantees the existence of a region, `ForallOp::ensureTerminator` is redundant and can be safely removed. Fixes #130019.
1 parent 17aca79 commit ac09b78

File tree

3 files changed

+12
-19
lines changed

3 files changed

+12
-19
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,12 +605,6 @@ def ForallOp : SCF_Op<"forall", [
605605
/// Checks if the lbs are zeros and steps are ones.
606606
bool isNormalized();
607607

608-
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
609-
// unaware of the fact that our terminator also needs a region to be
610-
// well-formed. We override it here to ensure that we do the right thing.
611-
static void ensureTerminator(Region & region, OpBuilder & builder,
612-
Location loc);
613-
614608
InParallelOp getTerminator();
615609

616610
// Declare the shared_outs as inits/outs to DestinationStyleOpInterface.

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,19 +1416,6 @@ bool ForallOp::isNormalized() {
14161416
return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
14171417
}
14181418

1419-
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
1420-
// unaware of the fact that our terminator also needs a region to be
1421-
// well-formed. We override it here to ensure that we do the right thing.
1422-
void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1423-
Location loc) {
1424-
OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1425-
ForallOp>::ensureTerminator(region, builder, loc);
1426-
auto terminator =
1427-
llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1428-
if (terminator.getRegion().empty())
1429-
builder.createBlock(&terminator.getRegion());
1430-
}
1431-
14321419
InParallelOp ForallOp::getTerminator() {
14331420
return cast<InParallelOp>(getBody()->getTerminator());
14341421
}

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,18 @@ func.func @mismatched_mapping(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>
672672

673673
// -----
674674

675+
func.func @forall_wrong_terminator_op() -> () {
676+
%c100 = arith.constant 100 : index
677+
// expected-error @+2 {{'scf.forall' op expects regions to end with 'scf.forall.in_parallel', found 'llvm.return'}}
678+
// expected-note @below {{in custom textual format, the absence of terminator implies 'scf.forall.in_parallel'}}
679+
scf.forall (%arg0) in (%c100) {
680+
llvm.return
681+
}
682+
return
683+
}
684+
685+
// -----
686+
675687
func.func @switch_wrong_case_count(%arg0: index) {
676688
// expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
677689
"scf.index_switch"(%arg0) ({

0 commit comments

Comments
 (0)