Skip to content

Conversation

terapines-osc-mlir
Copy link
Contributor

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Aug 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Terapines MLIR (terapines-osc-mlir)

Changes

This commmit is a supplement for #140374.
RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6


Full diff: https://github.com/llvm/llvm-project/pull/152439.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/FIRToSCF.cpp (+88-2)
  • (added) flang/test/Fir/FirToSCF/iter-while.fir (+99)
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 1902757e83bf3..b779a21089549 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -88,6 +88,91 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
   }
 };
 
+struct IterWhileConversion : public OpRewritePattern<fir::IterWhileOp> {
+  using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(fir::IterWhileOp iterWhileOp,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = iterWhileOp.getLoc();
+    Value lowerBound = iterWhileOp.getLowerBound();
+    Value upperBound = iterWhileOp.getUpperBound();
+    Value step = iterWhileOp.getStep();
+
+    Value okInit = iterWhileOp.getIterateIn();
+    ValueRange iterArgs = iterWhileOp.getInitArgs();
+
+    SmallVector<Value> initVals;
+    initVals.push_back(lowerBound);
+    initVals.push_back(okInit);
+    initVals.append(iterArgs.begin(), iterArgs.end());
+
+    SmallVector<Type> loopTypes;
+    loopTypes.push_back(lowerBound.getType());
+    loopTypes.push_back(okInit.getType());
+    for (auto val : iterArgs)
+      loopTypes.push_back(val.getType());
+
+    auto scfWhileOp = scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+    rewriter.createBlock(&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(),
+                         loopTypes,
+                         SmallVector<Location>(loopTypes.size(), loc));
+
+    rewriter.createBlock(&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(),
+                         loopTypes,
+                         SmallVector<Location>(loopTypes.size(), loc));
+
+    {
+      rewriter.setInsertionPointToStart(&scfWhileOp.getBefore().front());
+      auto args = scfWhileOp.getBefore().getArguments();
+      auto iv = args[0];
+      auto ok = args[1];
+
+      Value inductionCmp = mlir::arith::CmpIOp::create(
+          rewriter, loc, mlir::arith::CmpIPredicate::sle, iv, upperBound);
+      Value cmp = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, ok);
+
+      mlir::scf::ConditionOp::create(rewriter, loc, cmp, args);
+    }
+
+    {
+      rewriter.setInsertionPointToStart(&scfWhileOp.getAfter().front());
+      auto args = scfWhileOp.getAfter().getArguments();
+      auto iv = args[0];
+
+      mlir::IRMapping mapping;
+      for (auto [oldArg, newVal] :
+           llvm::zip(iterWhileOp.getBody()->getArguments(), args))
+        mapping.map(oldArg, newVal);
+
+      for (auto &op : iterWhileOp.getBody()->without_terminator())
+        rewriter.clone(op, mapping);
+
+      auto resultOp =
+          cast<fir::ResultOp>(iterWhileOp.getBody()->getTerminator());
+      auto results = resultOp.getResults();
+
+      SmallVector<Value> yieldedVals;
+
+      Value nextIv = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
+      yieldedVals.push_back(nextIv);
+
+      for (auto val : results.drop_front()) {
+        if (mapping.contains(val)) {
+          yieldedVals.push_back(mapping.lookup(val));
+        } else {
+          yieldedVals.push_back(val);
+        }
+      }
+
+      mlir::scf::YieldOp::create(rewriter, loc, yieldedVals);
+    }
+
+    rewriter.replaceOp(iterWhileOp, scfWhileOp);
+    return success();
+  }
+};
+
 void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
                                  Block &dstBlock) {
   Operation *srcTerminator = srcBlock.getTerminator();
@@ -130,9 +215,10 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
 
 void FIRToSCFPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
+  patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+      patterns.getContext());
   ConversionTarget target(getContext());
-  target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
+  target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir
new file mode 100644
index 0000000000000..a5de48f2ba848
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/iter-while.fir
@@ -0,0 +1,99 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 11 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 22 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant true
+// CHECK:           %[[VAL_4:.*]] = arith.constant 123 : i16
+// CHECK:           %[[VAL_5:.*]] = arith.constant 456 : i32
+// CHECK:           %[[VAL_6:.*]]:4 = scf.while (%[[VAL_7:.*]] = %[[VAL_0]], %[[VAL_8:.*]] = %[[VAL_3]], %[[VAL_9:.*]] = %[[VAL_4]], %[[VAL_10:.*]] = %[[VAL_5]]) : (index, i1, i16, i32) -> (index, i1, i16, i32) {
+// CHECK:             %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_12:.*]] = arith.andi %[[VAL_11]], %[[VAL_8]] : i1
+// CHECK:             scf.condition(%[[VAL_12]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, i1, i16, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i16, %[[VAL_16:.*]]: i32):
+// CHECK:             %[[VAL_17:.*]] = arith.constant true
+// CHECK:             %[[VAL_18:.*]] = arith.constant 22 : i16
+// CHECK:             %[[VAL_19:.*]] = arith.constant 33 : i32
+// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_20]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : index, i1, i16, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_21:.*]]#0, %[[VAL_21]]#1, %[[VAL_21]]#2, %[[VAL_21]]#3 : index, i1, i16, i32
+// CHECK:         }
+func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) {
+  %lo = arith.constant 11 : index
+  %up = arith.constant 22 : index
+  %step = arith.constant 2 : index
+  %ok = arith.constant 1 : i1
+  %val1 = arith.constant 123 : i16
+  %val2 = arith.constant 456 : i32
+
+  %res:4 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%v1 = %val1, %v2 = %val2) -> (index, i1, i16, i32) {
+    %new_c = arith.constant 1 : i1
+    %new_v1 = arith.constant 22 : i16
+    %new_v2 = arith.constant 33 : i32
+    fir.result %i, %new_c, %new_v1, %new_v2 : index, i1, i16, i32
+  }
+
+  return %res#0, %res#1, %res#2, %res#3 : index, i1, i16, i32
+}
+
+// CHECK-LABEL:   func.func @test_simple_iterate_while_2(
+// CHECK-SAME:        %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: i32) -> (index, i1, i32) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]]:3 = scf.while (%[[VAL_2:.*]] = %[[ARG0]], %[[VAL_3:.*]] = %[[ARG2]], %[[VAL_4:.*]] = %[[ARG3]]) : (index, i1, i32) -> (index, i1, i32) {
+// CHECK:             %[[VAL_5:.*]] = arith.cmpi sle, %[[VAL_2]], %[[ARG1]] : index
+// CHECK:             %[[VAL_6:.*]] = arith.andi %[[VAL_5]], %[[VAL_3]] : i1
+// CHECK:             scf.condition(%[[VAL_6]]) %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : index, i1, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: i1, %[[VAL_9:.*]]: i32):
+// CHECK:             %[[VAL_10:.*]] = arith.constant 123 : i32
+// CHECK:             %[[VAL_11:.*]] = arith.constant true
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_7]], %[[VAL_0]] : index
+// CHECK:             scf.yield %[[VAL_12]], %[[VAL_11]], %[[VAL_10]] : index, i1, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_13:.*]]#0, %[[VAL_13]]#1, %[[VAL_13]]#2 : index, i1, i32
+// CHECK:         }
+func.func @test_simple_iterate_while_2(%start: index, %stop: index, %cond: i1, %val: i32) -> (index, i1, i32) {
+  %step = arith.constant 1 : index
+
+  %res:3 = fir.iterate_while (%i = %start to %stop step %step) and (%ok = %cond) iter_args(%x = %val) -> (index, i1, i32) {
+    %new_x = arith.constant 123 : i32
+    %new_ok = arith.constant 1 : i1
+    fir.result %i, %new_ok, %new_x : index, i1, i32
+  }
+
+  return %res#0, %res#1, %res#2 : index, i1, i32
+}
+
+// CHECK-LABEL:   func.func @test_zero_iterations() -> (index, i1, i8) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant true
+// CHECK:           %[[VAL_4:.*]] = arith.constant 42 : i8
+// CHECK:           %[[VAL_5:.*]]:3 = scf.while (%[[VAL_6:.*]] = %[[VAL_0]], %[[VAL_7:.*]] = %[[VAL_3]], %[[VAL_8:.*]] = %[[VAL_4]]) : (index, i1, i8) -> (index, i1, i8) {
+// CHECK:             %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_10:.*]] = arith.andi %[[VAL_9]], %[[VAL_7]] : i1
+// CHECK:             scf.condition(%[[VAL_10]]) %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : index, i1, i8
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i8):
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index, i1, i8
+// CHECK:           }
+// CHECK:           return %[[VAL_15:.*]]#0, %[[VAL_15]]#1, %[[VAL_15]]#2 : index, i1, i8
+// CHECK:         }
+func.func @test_zero_iterations() -> (index, i1, i8) {
+  %lo = arith.constant 10 : index
+  %up = arith.constant 5 : index
+  %step = arith.constant 1 : index
+  %ok = arith.constant 1 : i1
+  %x = arith.constant 42 : i8
+
+  %res:3 = fir.iterate_while (%i = %lo to %up step %step) and (%c = %ok) iter_args(%xv = %x) -> (index, i1, i8) {
+    fir.result %i, %c, %xv : index, i1, i8
+  }
+
+  return %res#0, %res#1, %res#2 : index, i1, i8
+}

@clementval clementval requested a review from rscottmanley August 7, 2025 05:43

auto &afterBlock = *rewriter.createBlock(
&scfWhileOp.getAfter(), scfWhileOp.getAfter().end(), loopTypes,
SmallVector<Location>(loopTypes.size(), loc));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the original block meets the relevant requirements, I think we can use scfWhileOp.getAfter().takeBody instead of creating a new block.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@@ -10,6 +10,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include <mlir/Support/LLVM.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use #include "mlir/Support/LLVM.h" and sort the headers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This include is introduced by IDE, got it removed now, thanks 🤝

Copy link
Contributor

@c8ef c8ef left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

&scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
SmallVector<Location>(loopTypes.size(), loc));

auto beforeArgs = scfWhileOp.getBefore().getArguments();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont know what type this is... is it a vector of values? range of block arguments? -- the autos below would be okay if this one was specified.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auto is replaced by explicit type now.


auto beforeArgs = scfWhileOp.getBefore().getArguments();
auto beforeIv = beforeArgs[0];
auto beforeOk = beforeArgs[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is "beforeOk"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the early exit flag in fir.iterate_while. (the example in document use the name "ok", so I just used it)
It is now renamed to earlyExitInBefore. thanks 🤝

Copy link
Contributor

@NexMing NexMing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines 97 to 100
Location loc = iterWhileOp.getLoc();
Value lowerBound = iterWhileOp.getLowerBound();
Value upperBound = iterWhileOp.getUpperBound();
Value step = iterWhileOp.getStep();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a good reason to not follow the same style than all files in flang/lib/Lower? We usually have expanded namespace. Not specific to this PR but it would be nice to follow that same style.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there was not a good reason. The code has been modified to use expanded namespace. Thanks ! 🤝

Copy link

github-actions bot commented Aug 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@NexMing NexMing merged commit c164e63 into llvm:main Aug 14, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants