Skip to content

Commit cbbf741

Browse files
fix test
1 parent e73cf2f commit cbbf741

File tree

9 files changed

+141
-60
lines changed

9 files changed

+141
-60
lines changed

mlir/docs/DialectConversion.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ class TypeConverter {
352352
353353
/// This method registers a materialization that will be called when
354354
/// converting (potentially multiple) block arguments that were the result of
355-
/// a signature conversion of a single block argument, to a single SSA value.
355+
/// a signature conversion of a single block argument, to a single SSA value
356+
/// with the old argument type.
356357
template <typename FnT,
357358
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
358359
void addArgumentMaterialization(FnT &&callback) {

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
3838
let assemblyFormat = "attr-dict";
3939
}
4040

41+
def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
42+
"apply_conversion_patterns.scf.scf_to_control_flow",
43+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
44+
let description = [{
45+
Collects patterns that lower structured control flow ops to unstructured
46+
control flow.
47+
}];
48+
49+
let assemblyFormat = "attr-dict";
50+
}
51+
4152
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
4253

4354
def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,15 @@ class TypeConverter {
174174
/// where `T` is any subclass of `Type`. This function is responsible for
175175
/// creating an operation, using the OpBuilder and Location provided, that
176176
/// "casts" a range of values into a single value of the given type `T`. It
177-
/// must return a Value of the converted type on success, an `std::nullopt` if
177+
/// must return a Value of the type `T` on success, an `std::nullopt` if
178178
/// it failed but other materialization can be attempted, and `nullptr` on
179-
/// unrecoverable failure. It will only be called for (sub)types of `T`.
180-
/// Materialization functions must be provided when a type conversion may
181-
/// persist after the conversion has finished.
179+
/// unrecoverable failure. Materialization functions must be provided when a
180+
/// type conversion may persist after the conversion has finished.
182181

183182
/// This method registers a materialization that will be called when
184183
/// converting (potentially multiple) block arguments that were the result of
185-
/// a signature conversion of a single block argument, to a single SSA value.
184+
/// a signature conversion of a single block argument, to a single SSA value
185+
/// with the old block argument type.
186186
template <typename FnT, typename T = typename llvm::function_traits<
187187
std::decay_t<FnT>>::template arg_t<1>>
188188
void addArgumentMaterialization(FnT &&callback) {

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156-
// Materialization for memrefs creates descriptor structs from individual
157-
// values constituting them, when descriptors are used, i.e. more than one
158-
// value represents a memref.
156+
// Argument materializations convert from the new block argument types
157+
// (multiple SSA values that make up a memref descriptor) back to the
158+
// original block argument type. The dialect conversion framework will then
159+
// insert a target materialization from the original block argument type to
160+
// a legal type.
159161
addArgumentMaterialization(
160162
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
161163
Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
164166
// memref descriptor cannot be built just from a bare pointer.
165167
return std::nullopt;
166168
}
167-
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
168-
inputs);
169+
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
170+
resultType, inputs);
171+
// An argument materialization must return a value of type
172+
// `resultType`, so insert a cast from the memref descriptor type
173+
// (!llvm.struct) to the original memref type.
174+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175+
.getResult(0);
169176
});
170177
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
171178
ValueRange inputs,
172179
Location loc) -> std::optional<Value> {
180+
Value desc;
173181
if (inputs.size() == 1) {
174182
// This is a bare pointer. We allow bare pointers only for function entry
175183
// blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
180188
if (!block->isEntryBlock() ||
181189
!isa<FunctionOpInterface>(block->getParentOp()))
182190
return std::nullopt;
183-
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191+
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
184192
inputs[0]);
193+
} else {
194+
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
185195
}
186-
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
196+
// An argument materialization must return a value of type `resultType`,
197+
// so insert a cast from the memref descriptor type (!llvm.struct) to the
198+
// original memref type.
199+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
200+
.getResult(0);
187201
});
188202
// Add generic source and target materializations to handle cases where
189203
// non-LLVM types persist after an LLVM conversion.

mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
1313
MLIRIR
1414
MLIRLoopLikeInterface
1515
MLIRSCFDialect
16+
MLIRSCFToControlFlow
1617
MLIRSCFTransforms
1718
MLIRSCFUtils
1819
MLIRTransformDialect

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10+
11+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
1012
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1113
#include "mlir/Dialect/Affine/LoopUtils.h"
1214
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
4951
conversionTarget);
5052
}
5153

54+
void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
55+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
56+
populateSCFToControlFlowConversionPatterns(patterns);
57+
}
58+
5259
//===----------------------------------------------------------------------===//
5360
// ForallToForOp
5461
//===----------------------------------------------------------------------===//
@@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
261268
return 1;
262269
};
263270

264-
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
265-
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
271+
std::optional<int64_t> ubConstant =
272+
getConstantIntValue(forOp.getUpperBound());
273+
std::optional<int64_t> lbConstant =
274+
getConstantIntValue(forOp.getLowerBound());
266275
DenseMap<Operation *, unsigned> opCycles;
267276
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
268277
for (Operation &op : forOp.getBody()->getOperations()) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707707
UnresolvedMaterializationRewrite(
708708
ConversionPatternRewriterImpl &rewriterImpl,
709709
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710-
MaterializationKind kind = MaterializationKind::Target,
711-
Type origOutputType = nullptr)
710+
MaterializationKind kind = MaterializationKind::Target)
712711
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713-
converterAndKind(converter, kind), origOutputType(origOutputType) {}
712+
converterAndKind(converter, kind) {}
714713

715714
static bool classof(const IRRewrite *rewrite) {
716715
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734733
return converterAndKind.getInt();
735734
}
736735

737-
/// Return the original illegal output type of the input values.
738-
Type getOrigOutputType() const { return origOutputType; }
739-
740736
private:
741737
/// The corresponding type converter to use when resolving this
742738
/// materialization, and the kind of this materialization.
743739
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
744740
converterAndKind;
745-
746-
/// The original output type. This is only used for argument conversions.
747-
Type origOutputType;
748741
};
749742
} // namespace
750743

@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860853
Block *insertBlock,
861854
Block::iterator insertPt, Location loc,
862855
ValueRange inputs, Type outputType,
863-
Type origOutputType,
864856
const TypeConverter *converter);
865857

866858
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
867859
ValueRange inputs,
868-
Type origOutputType,
869860
Type outputType,
870861
const TypeConverter *converter);
871862

@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13881379
if (replArgs.size() == 1 &&
13891380
(!converter || replArgs[0].getType() == origArg.getType())) {
13901381
newArg = replArgs.front();
1382+
mapping.map(origArg, newArg);
13911383
} else {
1392-
Type origOutputType = origArg.getType();
1393-
1394-
// Legalize the argument output type.
1395-
Type outputType = origOutputType;
1396-
if (Type legalOutputType = converter->convertType(outputType))
1397-
outputType = legalOutputType;
1398-
1399-
newArg = buildUnresolvedArgumentMaterialization(
1400-
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
1401-
converter);
1384+
// Build argument materialization: new block arguments -> old block
1385+
// argument type.
1386+
Value argMat = buildUnresolvedArgumentMaterialization(
1387+
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1388+
mapping.map(origArg, argMat);
1389+
1390+
// Build target materialization: old block argument type -> legal type.
1391+
// Note: This function returns an "empty" type if no valid conversion to
1392+
// a legal type exists. In that case, we continue the conversion with the
1393+
// original block argument type.
1394+
Type legalOutputType = converter->convertType(origArg.getType());
1395+
if (legalOutputType && legalOutputType != origArg.getType()) {
1396+
newArg = buildUnresolvedTargetMaterialization(
1397+
origArg.getLoc(), argMat, legalOutputType, converter);
1398+
mapping.map(argMat, newArg);
1399+
} else {
1400+
newArg = argMat;
1401+
}
14021402
}
14031403

1404-
mapping.map(origArg, newArg);
14051404
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
14061405
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
14071406
}
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14241423
/// of input operands.
14251424
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14261425
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427-
Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1426+
Location loc, ValueRange inputs, Type outputType,
14281427
const TypeConverter *converter) {
14291428
// Avoid materializing an unnecessary cast.
14301429
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14351434
OpBuilder builder(insertBlock, insertPt);
14361435
auto convertOp =
14371436
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1438-
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439-
origOutputType);
1437+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14401438
return convertOp.getResult(0);
14411439
}
14421440
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1443-
Block *block, Location loc, ValueRange inputs, Type origOutputType,
1444-
Type outputType, const TypeConverter *converter) {
1441+
Block *block, Location loc, ValueRange inputs, Type outputType,
1442+
const TypeConverter *converter) {
14451443
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
14461444
block->begin(), loc, inputs, outputType,
1447-
origOutputType, converter);
1445+
converter);
14481446
}
14491447
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14501448
Location loc, Value input, Type outputType,
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14561454

14571455
return buildUnresolvedMaterialization(MaterializationKind::Target,
14581456
insertBlock, insertPt, loc, input,
1459-
outputType, outputType, converter);
1457+
outputType, converter);
14601458
}
14611459

14621460
//===----------------------------------------------------------------------===//
@@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations(
26722670
ConversionPatternRewriterImpl &rewriterImpl,
26732671
DenseMap<Value, SmallVector<Value>> &inverseMapping,
26742672
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2673+
// Helper function to check if the given value or a not yet materialized
2674+
// replacement of the given value is live.
2675+
// Note: `inverseMapping` maps from replaced values to original values.
26752676
auto isLive = [&](Value value) {
26762677
auto findFn = [&](Operation *user) {
26772678
auto matIt = materializationOps.find(user);
26782679
if (matIt != materializationOps.end())
26792680
return !necessaryMaterializations.count(matIt->second);
26802681
return rewriterImpl.isOpIgnored(user);
26812682
};
2682-
// This value may be replacing another value that has a live user.
2683-
for (Value inv : inverseMapping.lookup(value))
2684-
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2683+
// A worklist is needed because a value may have gone through a chain of
2684+
// replacements and each of the replaced values may have live users.
2685+
SmallVector<Value> worklist;
2686+
worklist.push_back(value);
2687+
while (!worklist.empty()) {
2688+
Value next = worklist.pop_back_val();
2689+
if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
26852690
return true;
2686-
// Or have live users itself.
2687-
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2691+
// This value may be replacing another value that has a live user.
2692+
llvm::append_range(worklist, inverseMapping.lookup(next));
2693+
}
2694+
return false;
26882695
};
26892696

26902697
llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28442851
switch (mat.getMaterializationKind()) {
28452852
case MaterializationKind::Argument:
28462853
// Try to materialize an argument conversion.
2847-
// FIXME: The current argument materialization hook expects the original
2848-
// output type, even though it doesn't use that as the actual output type
2849-
// of the generated IR. The output type is just used as an indicator of
2850-
// the type of materialization to do. This behavior is really awkward in
2851-
// that it diverges from the behavior of the other hooks, and can be
2852-
// easily misunderstood. We should clean up the argument hooks to better
2853-
// represent the desired invariants we actually care about.
28542854
newMaterialization = converter->materializeArgumentConversion(
2855-
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2855+
rewriter, op->getLoc(), outputType, inputOperands);
28562856
if (newMaterialization)
28572857
break;
2858-
28592858
// If an argument materialization failed, fallback to trying a target
28602859
// materialization.
28612860
[[fallthrough]];
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28652864
break;
28662865
}
28672866
if (newMaterialization) {
2867+
assert(newMaterialization.getType() == outputType &&
2868+
"materialization callback produced value of incorrect type");
28682869
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
28692870
inverseMapping);
28702871
return success();

mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
22

3-
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
3+
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
44

5-
// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
5+
// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
66

77
// These tests were separated from func-memref.mlir because applying
88
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
2+
3+
// CHECK-LABEL: func @complex_block_signature_conversion(
4+
// CHECK: %[[cst:.*]] = complex.constant
5+
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
6+
// Note: Some blocks are omitted.
7+
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]]
8+
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
9+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
10+
// CHECK: llvm.br ^[[block2:.*]]
11+
// CHECK: ^[[block2]]:
12+
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
13+
func.func @complex_block_signature_conversion() {
14+
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
15+
%true = arith.constant true
16+
%0 = scf.if %true -> complex<f64> {
17+
scf.yield %cst : complex<f64>
18+
} else {
19+
scf.yield %cst : complex<f64>
20+
}
21+
22+
// Regression test to ensure that the a source materialization is inserted.
23+
// The operand of "test.consumer_of_complex" must not change.
24+
"test.consumer_of_complex"(%0) : (complex<f64>) -> ()
25+
return
26+
}
27+
28+
module attributes {transform.with_named_sequence} {
29+
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
30+
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
31+
: (!transform.any_op) -> !transform.any_op
32+
transform.apply_conversion_patterns to %func {
33+
transform.apply_conversion_patterns.dialect_to_llvm "cf"
34+
transform.apply_conversion_patterns.func.func_to_llvm
35+
transform.apply_conversion_patterns.scf.scf_to_control_flow
36+
} with type_converter {
37+
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
38+
} {
39+
legal_dialects = ["llvm"],
40+
partial_conversion
41+
} : !transform.any_op
42+
transform.yield
43+
}
44+
}

0 commit comments

Comments
 (0)