Skip to content

Commit d79cff7

Browse files
committed
[mlir][func][bufferization] Fix incompatible cast when bufferizing CallOp
Handle caller/callee type mismatch using `castOrReallocMemRefValue` instead of just a `CastOp`. The method inserts a reallocation + copy if it cannot statically guarantee that a direct cast would be valid.
1 parent 76236fa commit d79cff7

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,20 +258,25 @@ struct CallOpInterface
258258
return failure();
259259
Value buffer = *maybeBuffer;
260260

261-
// Caller / callee type mismatch is handled with a CastOp.
261+
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
262262
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
263263
// Since we don't yet have a clear layout story, to_memref may
264264
// conservatively turn tensors into more dynamic memref than necessary.
265265
// If the memref type of the callee fails, introduce an extra memref.cast
266266
// that will either canonicalize away or fail compilation until we can do
267-
// something better.
267+
// something better. Insert a reallocation + copy if it cannot be
268+
// statically guaranteed that a direct cast would be valid.
268269
if (buffer.getType() != memRefType) {
269-
assert(
270-
memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
271-
"CallOp::bufferize: cast incompatible");
272-
Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
273-
memRefType, buffer);
274-
buffer = castBuffer;
270+
auto memrefDestType = dyn_cast<MemRefType>(memRefType);
271+
assert(memrefDestType &&
272+
"buffer layout not supported on unranked tensors");
273+
BufferizationOptions options;
274+
options.bufferAlignment = 0;
275+
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
276+
rewriter, buffer, memrefDestType, options);
277+
if (failed(replacement))
278+
return failure();
279+
buffer = *replacement;
275280
}
276281
newOperands.push_back(buffer);
277282
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
7171

7272
// -----
7373

74+
// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
75+
// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
76+
// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
77+
// CHECK-NO-LAYOUT-MAP: }
78+
func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
79+
return %arg0 : tensor<3x8xf16>
80+
}
81+
82+
// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
83+
// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
84+
// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
85+
// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() : memref<3x8xf16>
86+
// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
87+
// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
88+
// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
89+
// CHECK-NO-LAYOUT-MAP: }
90+
func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
91+
%0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
92+
%1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
93+
return %1 : tensor<3x8xf16>
94+
}
95+
96+
// -----
97+
7498
// CHECK-LABEL: func private @private_func
7599
// CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
76100
func.func private @private_func(tensor<?xf32>) -> (f32)

0 commit comments

Comments
 (0)