Skip to content

Commit f766cd2

Browse files
committed
Patch fixes for dialect conversion.
1 parent 792b673 commit f766cd2

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -832,11 +832,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
832832
Block *insertBlock,
833833
Block::iterator insertPt, Location loc,
834834
ValueRange inputs, Type outputType,
835-
const TypeConverter *converter);
835+
const TypeConverter *converter,
836+
MLIRContext *context);
836837

837838
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
838839
Type outputType,
839-
const TypeConverter *converter);
840+
const TypeConverter *converter,
841+
MLIRContext *context);
840842

841843
//===--------------------------------------------------------------------===//
842844
// Rewriter Notification Hooks
@@ -1185,7 +1187,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11851187
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
11861188
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
11871189
Value castValue = buildUnresolvedTargetMaterialization(
1188-
operandLoc, newOperand, desiredType, currentTypeConverter);
1190+
operandLoc, newOperand, desiredType, currentTypeConverter,
1191+
rewriter.getContext());
11891192
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
11901193
newOperand = castValue;
11911194
}
@@ -1300,7 +1303,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13001303
Value repl = buildUnresolvedMaterialization(
13011304
MaterializationKind::Source, newBlock, newBlock->begin(),
13021305
origArg.getLoc(), /*inputs=*/ValueRange(),
1303-
/*outputType=*/origArgType, converter);
1306+
/*outputType=*/origArgType, converter, rewriter.getContext());
13041307
mapping.map(origArg, repl);
13051308
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13061309
continue;
@@ -1324,7 +1327,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13241327
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13251328
Value argMat = buildUnresolvedMaterialization(
13261329
MaterializationKind::Argument, newBlock, newBlock->begin(),
1327-
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1330+
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter,
1331+
rewriter.getContext());
13281332
mapping.map(origArg, argMat);
13291333
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13301334

@@ -1339,7 +1343,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13391343
legalOutputType = converter->convertType(origArgType);
13401344
if (legalOutputType && legalOutputType != origArgType) {
13411345
Value targetMat = buildUnresolvedTargetMaterialization(
1342-
origArg.getLoc(), argMat, legalOutputType, converter);
1346+
origArg.getLoc(), argMat, legalOutputType, converter,
1347+
rewriter.getContext());
13431348
mapping.map(argMat, targetMat);
13441349
}
13451350
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1363,30 +1368,32 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13631368
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13641369
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
13651370
Location loc, ValueRange inputs, Type outputType,
1366-
const TypeConverter *converter) {
1371+
const TypeConverter *converter, MLIRContext *context) {
13671372
// Avoid materializing an unnecessary cast.
13681373
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13691374
return inputs.front();
13701375

13711376
// Create an unresolved materialization. We use a new OpBuilder to avoid
13721377
// tracking the materialization like we do for other operations.
1373-
OpBuilder builder(insertBlock, insertPt);
1378+
// OpBuilder builder(context, insertBlock, insertPt);
1379+
OpBuilder builder(context);
1380+
builder.setInsertionPoint(insertBlock, insertPt);
13741381
auto convertOp =
13751382
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
13761383
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13771384
return convertOp.getResult(0);
13781385
}
13791386
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1380-
Location loc, Value input, Type outputType,
1381-
const TypeConverter *converter) {
1387+
Location loc, Value input, Type outputType, const TypeConverter *converter,
1388+
MLIRContext *context) {
13821389
Block *insertBlock = input.getParentBlock();
13831390
Block::iterator insertPt = insertBlock->begin();
13841391
if (OpResult inputRes = dyn_cast<OpResult>(input))
13851392
insertPt = ++inputRes.getOwner()->getIterator();
13861393

13871394
return buildUnresolvedMaterialization(MaterializationKind::Target,
13881395
insertBlock, insertPt, loc, input,
1389-
outputType, converter);
1396+
outputType, converter, context);
13901397
}
13911398

13921399
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)