@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
400
400
401
401
// Replace all results with the yielded values.
402
402
auto yieldOp = cast<scf::YieldOp>(getBody ()->getTerminator ());
403
- rewriter.replaceAllUsesWith (getResults (), yieldOp. getOperands ());
403
+ rewriter.replaceAllUsesWith (getResults (), getYieldedValues ());
404
404
405
405
// Replace block arguments with lower bound (replacement for IV) and
406
406
// iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
772
772
LogicalResult matchAndRewrite (scf::ForOp forOp,
773
773
PatternRewriter &rewriter) const final {
774
774
bool canonicalize = false ;
775
- Block &block = forOp.getRegion ().front ();
776
- auto yieldOp = cast<scf::YieldOp>(block.getTerminator ());
777
775
778
776
// An internal flat vector of block transfer
779
777
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
780
778
// transformed block argument mappings. This plays the role of a
781
779
// IRMapping for the particular use case of calling into
782
780
// `inlineBlockBefore`.
781
+ int64_t numResults = forOp.getNumResults ();
783
782
SmallVector<bool , 4 > keepMask;
784
- keepMask.reserve (yieldOp. getNumOperands () );
783
+ keepMask.reserve (numResults );
785
784
SmallVector<Value, 4 > newBlockTransferArgs, newIterArgs, newYieldValues,
786
785
newResultValues;
787
- newBlockTransferArgs.reserve (1 + forOp. getInitArgs (). size () );
786
+ newBlockTransferArgs.reserve (1 + numResults );
788
787
newBlockTransferArgs.push_back (Value ()); // iv placeholder with null value
789
788
newIterArgs.reserve (forOp.getInitArgs ().size ());
790
- newYieldValues.reserve (yieldOp. getNumOperands () );
791
- newResultValues.reserve (forOp. getNumResults () );
789
+ newYieldValues.reserve (numResults );
790
+ newResultValues.reserve (numResults );
792
791
for (auto it : llvm::zip (forOp.getInitArgs (), // iter from outside
793
792
forOp.getRegionIterArgs (), // iter inside region
794
793
forOp.getResults (), // op results
795
- yieldOp. getOperands () // iter yield
794
+ forOp. getYieldedValues () // iter yield
796
795
)) {
797
796
// Forwarded is `true` when:
798
797
// 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
946
945
return failure ();
947
946
// If the loop is empty, iterates at least once, and only returns values
948
947
// defined outside of the loop, remove it and replace it with yield values.
949
- auto yieldOp = cast<scf::YieldOp>(block.getTerminator ());
950
- auto yieldOperands = yieldOp.getOperands ();
951
- if (llvm::any_of (yieldOperands,
948
+ if (llvm::any_of (op.getYieldedValues (),
952
949
[&](Value v) { return !op.isDefinedOutsideOfLoop (v); }))
953
950
return failure ();
954
- rewriter.replaceOp (op, yieldOperands );
951
+ rewriter.replaceOp (op, op. getYieldedValues () );
955
952
return success ();
956
953
}
957
954
};
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
1224
1221
return {};
1225
1222
}
1226
1223
1224
+ ValueRange ForOp::getYieldedValues () {
1225
+ return cast<scf::YieldOp>(getBody ()->getTerminator ()).getResults ();
1226
+ }
1227
+
1227
1228
Speculation::Speculatability ForOp::getSpeculatability () {
1228
1229
// `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1229
1230
// and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
3205
3206
return cast<YieldOp>(getAfterBody ()->getTerminator ());
3206
3207
}
3207
3208
3209
+ ValueRange WhileOp::getYieldedValues () { return getYieldOp ().getResults (); }
3210
+
3208
3211
Block::BlockArgListType WhileOp::getBeforeArguments () {
3209
3212
return getBeforeBody ()->getArguments ();
3210
3213
}
0 commit comments