From 93dfcb8520fae39bede2fafb0823d8ea864aee73 Mon Sep 17 00:00:00 2001 From: "tashuang.zk" Date: Sat, 4 Apr 2020 08:16:32 +0800 Subject: [PATCH] [MLIR] revise IfLowering of LoopToStandard Pass to support IfOp with return values --- .../Conversion/LoopToStandard/LoopToStandard.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp index e72c83027611b..2a256d870a770 100644 --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -244,7 +244,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, // place it before the continuation block, and branch to it. auto &thenRegion = ifOp.thenRegion(); auto *thenBlock = &thenRegion.front(); - rewriter.eraseOp(thenRegion.back().getTerminator()); + auto thenTerminator = thenRegion.back().getTerminator(); + continueBlock->addArguments(thenTerminator->getOperandTypes()); + auto thenYieldValues = thenTerminator->getOperands(); + rewriter.eraseOp(thenTerminator); rewriter.setInsertionPointToEnd(&thenRegion.back()); rewriter.create(loc, continueBlock); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -256,9 +259,14 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, auto &elseRegion = ifOp.elseRegion(); if (!elseRegion.empty()) { elseBlock = &elseRegion.front(); - rewriter.eraseOp(elseRegion.back().getTerminator()); + auto elseTerminator = elseRegion.back().getTerminator(); + assert(elseTerminator->getOperandTypes() == + thenTerminator->getOperandTypes() && + "Yield values from thenBlock and elseBlock mismatch"); + auto elseYieldValues = elseTerminator->getOperands(); + rewriter.eraseOp(elseTerminator); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, continueBlock); + rewriter.create(loc, continueBlock, elseYieldValues); rewriter.inlineRegionBefore(elseRegion, continueBlock); }