diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 79ed85fa60607..2bca0d98ec68d 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -36,7 +36,7 @@ struct DoLoopConversion : public mlir::OpRewritePattern { mlir::Value high = doLoopOp.getUpperBound(); assert(low && high && "must be a Value"); mlir::Value step = doLoopOp.getStep(); - llvm::SmallVector iterArgs; + mlir::SmallVector iterArgs; if (hasFinalValue) iterArgs.push_back(low); iterArgs.append(doLoopOp.getIterOperands().begin(), @@ -88,6 +88,73 @@ struct DoLoopConversion : public mlir::OpRewritePattern { } }; +struct IterWhileConversion : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(fir::IterWhileOp iterWhileOp, + mlir::PatternRewriter &rewriter) const override { + + mlir::Location loc = iterWhileOp.getLoc(); + mlir::Value lowerBound = iterWhileOp.getLowerBound(); + mlir::Value upperBound = iterWhileOp.getUpperBound(); + mlir::Value step = iterWhileOp.getStep(); + + mlir::Value okInit = iterWhileOp.getIterateIn(); + mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); + + mlir::SmallVector initVals; + initVals.push_back(lowerBound); + initVals.push_back(okInit); + initVals.append(iterArgs.begin(), iterArgs.end()); + + mlir::SmallVector loopTypes; + loopTypes.push_back(lowerBound.getType()); + loopTypes.push_back(okInit.getType()); + for (auto val : iterArgs) + loopTypes.push_back(val.getType()); + + auto scfWhileOp = + mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals); + + auto &beforeBlock = *rewriter.createBlock( + &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes, + mlir::SmallVector(loopTypes.size(), loc)); + + mlir::Region::BlockArgListType argsInBefore = + scfWhileOp.getBefore().getArguments(); + auto ivInBefore = argsInBefore[0]; + auto earlyExitInBefore = argsInBefore[1]; + + rewriter.setInsertionPointToStart(&beforeBlock); + + mlir::Value inductionCmp = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound); + mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, + earlyExitInBefore); + + mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore); + + rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(), + scfWhileOp.getAfter().begin()); + + auto *afterBody = scfWhileOp.getAfterBody(); + auto resultOp = mlir::cast(afterBody->getTerminator()); + mlir::SmallVector results(resultOp->getOperands()); + mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; + + rewriter.setInsertionPointToStart(afterBody); + results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); + + rewriter.setInsertionPointToEnd(afterBody); + rewriter.replaceOpWithNewOp(resultOp, results); + + scfWhileOp->setAttrs(iterWhileOp->getAttrs()); + rewriter.replaceOp(iterWhileOp, scfWhileOp); + return mlir::success(); + } +}; + void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter, mlir::Block &srcBlock, mlir::Block &dstBlock) { mlir::Operation *srcTerminator = srcBlock.getTerminator(); @@ -132,9 +199,10 @@ struct IfConversion : public mlir::OpRewritePattern { void FIRToSCFPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); mlir::ConversionTarget target(getContext()); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir new file mode 100644 index 0000000000000..0de7aabed120e --- /dev/null +++ b/flang/test/Fir/FirToSCF/iter-while.fir @@ -0,0 +1,99 @@ +// RUN: fir-opt %s --fir-to-scf | FileCheck %s + +// CHECK-LABEL: func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) { +// CHECK: %[[VAL_0:.*]] = arith.constant 11 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 22 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant true +// CHECK: %[[VAL_4:.*]] = arith.constant 123 : i16 +// CHECK: %[[VAL_5:.*]] = arith.constant 456 : i32 +// CHECK: %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) { +// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index +// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1 +// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32): +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index +// CHECK: %[[VAL_18:.*]] = arith.constant true +// CHECK: %[[VAL_19:.*]] = arith.constant 22 : i16 +// CHECK: %[[VAL_20:.*]] = arith.constant 33 : i32 +// CHECK: scf.yield %[[VAL_17]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, i1, i16, i32 +// CHECK: } +// CHECK: return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32 +// CHECK: } +func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) { + %lo = arith.constant 11 : index + %up = arith.constant 22 : index + %step = arith.constant 2 : index + %ok = arith.constant 1 : i1 + %val1 = arith.constant 123 : i16 + %val2 = arith.constant 456 : i32 + + %res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) { + %new_c = arith.constant 1 : i1 + %new_v1 = arith.constant 22 : i16 + %new_v2 = arith.constant 33 : i32 + fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32 + } + + return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32 +} + +// CHECK-LABEL: func.func @test_simple_iterate_while_2( +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) { +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) { +// CHECK: %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index +// CHECK: %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1 +// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32): +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index +// CHECK: %[[VAL_11:.*]] = arith.constant 123 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant true +// CHECK: scf.yield %[[VAL_10]], %[[VAL_12]], %[[VAL_11]] : index, i1, i32 +// CHECK: } +// CHECK: return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32 +// CHECK: } +func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) { + %step = arith.constant 1 : index + + %res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) { + %new_x = arith.constant 123 : i32 + %new_ok = arith.constant 1 : i1 + fir.result %i, %new_ok, %new_x : index, i1, i32 + } + + return %res#0, %res#1, %res#2 : index, i1, i32 +} + +// CHECK-LABEL: func.func @test_zero_iterations() -> (index, i1, i8) { +// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant true +// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i8 +// CHECK: %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) { +// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index +// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1 +// CHECK: scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8 +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8): +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index +// CHECK: scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8 +// CHECK: } +// CHECK: return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8 +// CHECK: } +func.func @test_zero_iterations() -> (index, i1, i8) { + %lo = arith.constant 10 : index + %up = arith.constant 5 : index + %step = arith.constant 1 : index + %ok = arith.constant 1 : i1 + %x = arith.constant 42 : i8 + + %res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) { + fir.result %i, %c, %xv : index, i1, i8 + } + + return %res#0, %res#1, %res#2 : index, i1, i8 +}