Skip to content

Commit cfea4ad

Browse files
fix test
1 parent 55b95a7 commit cfea4ad

File tree

9 files changed

+130
-54
lines changed

9 files changed

+130
-54
lines changed

mlir/docs/DialectConversion.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ class TypeConverter {
352352
353353
/// This method registers a materialization that will be called when
354354
/// converting (potentially multiple) block arguments that were the result of
355-
/// a signature conversion of a single block argument, to a single SSA value.
355+
/// a signature conversion of a single block argument, to a single SSA value
356+
/// with the old argument type.
356357
template <typename FnT,
357358
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
358359
void addArgumentMaterialization(FnT &&callback) {

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
3838
let assemblyFormat = "attr-dict";
3939
}
4040

41+
def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
42+
"apply_conversion_patterns.scf.scf_to_control_flow",
43+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
44+
let description = [{
45+
Collects patterns that lower structured control flow ops to unstructured
46+
control flow.
47+
}];
48+
49+
let assemblyFormat = "attr-dict";
50+
}
51+
4152
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
4253

4354
def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ class TypeConverter {
182182

183183
/// This method registers a materialization that will be called when
184184
/// converting (potentially multiple) block arguments that were the result of
185-
/// a signature conversion of a single block argument, to a single SSA value.
185+
/// a signature conversion of a single block argument, to a single SSA value
186+
/// with the old block argument type.
186187
template <typename FnT, typename T = typename llvm::function_traits<
187188
std::decay_t<FnT>>::template arg_t<1>>
188189
void addArgumentMaterialization(FnT &&callback) {

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156-
// Materialization for memrefs creates descriptor structs from individual
157-
// values constituting them, when descriptors are used, i.e. more than one
158-
// value represents a memref.
156+
// Argument materializations convert from the new block argument types
157+
// (multiple SSA values that make up a memref descriptor) back to the
158+
// original block argument type. The dialect conversion framework will then
159+
// insert a target materialization from the original block argument type to
160+
// a legal type.
159161
addArgumentMaterialization(
160162
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
161163
Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
164166
// memref descriptor cannot be built just from a bare pointer.
165167
return std::nullopt;
166168
}
167-
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
168-
inputs);
169+
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
170+
resultType, inputs);
171+
// An argument materialization must return a value of type
172+
// `resultType`, so insert a cast from the memref descriptor type
173+
// (!llvm.struct) to the original memref type.
174+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175+
.getResult(0);
169176
});
170177
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
171178
ValueRange inputs,
172179
Location loc) -> std::optional<Value> {
180+
Value desc;
173181
if (inputs.size() == 1) {
174182
// This is a bare pointer. We allow bare pointers only for function entry
175183
// blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
180188
if (!block->isEntryBlock() ||
181189
!isa<FunctionOpInterface>(block->getParentOp()))
182190
return std::nullopt;
183-
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191+
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
184192
inputs[0]);
193+
} else {
194+
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
185195
}
186-
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
196+
// An argument materialization must return a value of type `resultType`,
197+
// so insert a cast from the memref descriptor type (!llvm.struct) to the
198+
// original memref type.
199+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
200+
.getResult(0);
187201
});
188202
// Add generic source and target materializations to handle cases where
189203
// non-LLVM types persist after an LLVM conversion.

mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
1313
MLIRIR
1414
MLIRLoopLikeInterface
1515
MLIRSCFDialect
16+
MLIRSCFToControlFlow
1617
MLIRSCFTransforms
1718
MLIRSCFUtils
1819
MLIRTransformDialect

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10+
11+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
1012
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1113
#include "mlir/Dialect/Affine/LoopUtils.h"
1214
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
4951
conversionTarget);
5052
}
5153

54+
void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
55+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
56+
populateSCFToControlFlowConversionPatterns(patterns);
57+
}
58+
5259
//===----------------------------------------------------------------------===//
5360
// ForallToForOp
5461
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707707
UnresolvedMaterializationRewrite(
708708
ConversionPatternRewriterImpl &rewriterImpl,
709709
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710-
MaterializationKind kind = MaterializationKind::Target,
711-
Type origOutputType = nullptr)
710+
MaterializationKind kind = MaterializationKind::Target)
712711
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713-
converterAndKind(converter, kind), origOutputType(origOutputType) {}
712+
converterAndKind(converter, kind) {}
714713

715714
static bool classof(const IRRewrite *rewrite) {
716715
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734733
return converterAndKind.getInt();
735734
}
736735

737-
/// Return the original illegal output type of the input values.
738-
Type getOrigOutputType() const { return origOutputType; }
739-
740736
private:
741737
/// The corresponding type converter to use when resolving this
742738
/// materialization, and the kind of this materialization.
743739
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
744740
converterAndKind;
745-
746-
/// The original output type. This is only used for argument conversions.
747-
Type origOutputType;
748741
};
749742
} // namespace
750743

@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860853
Block *insertBlock,
861854
Block::iterator insertPt, Location loc,
862855
ValueRange inputs, Type outputType,
863-
Type origOutputType,
864856
const TypeConverter *converter);
865857

866858
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
867859
ValueRange inputs,
868-
Type origOutputType,
869860
Type outputType,
870861
const TypeConverter *converter);
871862

@@ -1388,20 +1379,24 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13881379
if (replArgs.size() == 1 &&
13891380
(!converter || replArgs[0].getType() == origArg.getType())) {
13901381
newArg = replArgs.front();
1382+
mapping.map(origArg, newArg);
13911383
} else {
1392-
Type origOutputType = origArg.getType();
1393-
1394-
// Legalize the argument output type.
1395-
Type outputType = origOutputType;
1396-
if (Type legalOutputType = converter->convertType(outputType))
1397-
outputType = legalOutputType;
1398-
1399-
newArg = buildUnresolvedArgumentMaterialization(
1400-
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
1401-
converter);
1384+
// Build argument materialization: new block arguments -> old block
1385+
// argument type.
1386+
Value argMat = buildUnresolvedArgumentMaterialization(
1387+
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1388+
mapping.map(origArg, argMat);
1389+
1390+
// Build target materialization: old block argument type -> legal type.
1391+
if (Type legalOutputType = converter->convertType(origArg.getType())) {
1392+
newArg = buildUnresolvedTargetMaterialization(
1393+
origArg.getLoc(), argMat, legalOutputType, converter);
1394+
mapping.map(argMat, newArg);
1395+
} else {
1396+
newArg = argMat;
1397+
}
14021398
}
14031399

1404-
mapping.map(origArg, newArg);
14051400
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
14061401
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
14071402
}
@@ -1424,7 +1419,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14241419
/// of input operands.
14251420
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14261421
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427-
Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1422+
Location loc, ValueRange inputs, Type outputType,
14281423
const TypeConverter *converter) {
14291424
// Avoid materializing an unnecessary cast.
14301425
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1430,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14351430
OpBuilder builder(insertBlock, insertPt);
14361431
auto convertOp =
14371432
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1438-
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439-
origOutputType);
1433+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14401434
return convertOp.getResult(0);
14411435
}
14421436
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1443-
Block *block, Location loc, ValueRange inputs, Type origOutputType,
1444-
Type outputType, const TypeConverter *converter) {
1437+
Block *block, Location loc, ValueRange inputs, Type outputType,
1438+
const TypeConverter *converter) {
14451439
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
14461440
block->begin(), loc, inputs, outputType,
1447-
origOutputType, converter);
1441+
converter);
14481442
}
14491443
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14501444
Location loc, Value input, Type outputType,
@@ -1456,7 +1450,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14561450

14571451
return buildUnresolvedMaterialization(MaterializationKind::Target,
14581452
insertBlock, insertPt, loc, input,
1459-
outputType, outputType, converter);
1453+
outputType, converter);
14601454
}
14611455

14621456
//===----------------------------------------------------------------------===//
@@ -2672,19 +2666,28 @@ static void computeNecessaryMaterializations(
26722666
ConversionPatternRewriterImpl &rewriterImpl,
26732667
DenseMap<Value, SmallVector<Value>> &inverseMapping,
26742668
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2669+
// Helper function to check if the given value or a not yet materialized
2670+
// replacement of the given value is live.
2671+
// Note: `inverseMapping` maps from replaced values to original values.
26752672
auto isLive = [&](Value value) {
26762673
auto findFn = [&](Operation *user) {
26772674
auto matIt = materializationOps.find(user);
26782675
if (matIt != materializationOps.end())
26792676
return !necessaryMaterializations.count(matIt->second);
26802677
return rewriterImpl.isOpIgnored(user);
26812678
};
2682-
// This value may be replacing another value that has a live user.
2683-
for (Value inv : inverseMapping.lookup(value))
2684-
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2679+
// A worklist is needed because a value may have gone through a chain of
2680+
// replacements and each of the replaced values may have live users.
2681+
SmallVector<Value> worklist;
2682+
worklist.push_back(value);
2683+
while (!worklist.empty()) {
2684+
Value next = worklist.pop_back_val();
2685+
if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
26852686
return true;
2686-
// Or have live users itself.
2687-
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2687+
// This value may be replacing another value that has a live user.
2688+
llvm::append_range(worklist, inverseMapping.lookup(next));
2689+
}
2690+
return false;
26882691
};
26892692

26902693
llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2847,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28442847
switch (mat.getMaterializationKind()) {
28452848
case MaterializationKind::Argument:
28462849
// Try to materialize an argument conversion.
2847-
// FIXME: The current argument materialization hook expects the original
2848-
// output type, even though it doesn't use that as the actual output type
2849-
// of the generated IR. The output type is just used as an indicator of
2850-
// the type of materialization to do. This behavior is really awkward in
2851-
// that it diverges from the behavior of the other hooks, and can be
2852-
// easily misunderstood. We should clean up the argument hooks to better
2853-
// represent the desired invariants we actually care about.
28542850
newMaterialization = converter->materializeArgumentConversion(
2855-
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2851+
rewriter, op->getLoc(), outputType, inputOperands);
28562852
if (newMaterialization)
28572853
break;
2858-
28592854
// If an argument materialization failed, fallback to trying a target
28602855
// materialization.
28612856
[[fallthrough]];
@@ -2865,6 +2860,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28652860
break;
28662861
}
28672862
if (newMaterialization) {
2863+
assert(newMaterialization.getType() == opResult.getType() &&
2864+
"materialization callback produced value of incorrect type");
28682865
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
28692866
inverseMapping);
28702867
return success();

mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
22

3-
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
3+
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
44

5-
// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
5+
// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
66

77
// These tests were separated from func-memref.mlir because applying
88
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
2+
3+
// CHECK-LABEL: func @complex_block_signature_conversion(
4+
// CHECK: %[[cst:.*]] = complex.constant
5+
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
6+
// Note: Some blocks are omitted.
7+
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]]
8+
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
9+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
10+
// CHECK: llvm.br ^[[block2:.*]]
11+
// CHECK: ^[[block2]]:
12+
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
13+
func.func @complex_block_signature_conversion() {
14+
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
15+
%true = arith.constant true
16+
%0 = scf.if %true -> complex<f64> {
17+
scf.yield %cst : complex<f64>
18+
} else {
19+
scf.yield %cst : complex<f64>
20+
}
21+
22+
// Regression test to ensure that the a source materialization is inserted.
23+
// The operand of "test.consumer_of_complex" must not change.
24+
"test.consumer_of_complex"(%0) : (complex<f64>) -> ()
25+
return
26+
}
27+
28+
module attributes {transform.with_named_sequence} {
29+
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
30+
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
31+
: (!transform.any_op) -> !transform.any_op
32+
transform.apply_conversion_patterns to %func {
33+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
34+
transform.apply_conversion_patterns.func.func_to_llvm
35+
transform.apply_conversion_patterns.scf.scf_to_control_flow
36+
} with type_converter {
37+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
38+
} {
39+
legal_dialects = ["llvm"],
40+
partial_conversion
41+
} : !transform.any_op
42+
transform.yield
43+
}
44+
}

0 commit comments

Comments
 (0)