[mlir][func][bufferization] Fix cast incompatible when bufferize callOp #105929
[mlir][func][bufferization] Fix cast incompatible when bufferize callOp #105929
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes: Handle caller/callee type mismatch using castOrReallocMemRefValue instead of just a CastOp. The method inserts a reallocation + copy if it cannot be statically guaranteed that a direct cast would be valid.

Full diff: https://github.com/llvm/llvm-project/pull/105929.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..f85d0c35c5af33 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,25 @@ struct CallOpInterface
return failure();
Value buffer = *maybeBuffer;
- // Caller / callee type mismatch is handled with a CastOp.
+ // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
- // something better.
+ // something better. Insert a reallocation + copy if it cannot be
+ // statically guaranteed that a direct cast would be valid.
if (buffer.getType() != memRefType) {
- assert(
- memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
- "CallOp::bufferize: cast incompatible");
- Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
- memRefType, buffer);
- buffer = castBuffer;
+ auto memrefDestType = dyn_cast<MemRefType>(memRefType);
+ assert(memrefDestType &&
+ "buffer layout not supported on unranked tensors");
+ BufferizationOptions options;
+ options.bufferAlignment = 0;
+ FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+ rewriter, buffer, memrefDestType, options);
+ if (failed(replacement))
+ return failure();
+ buffer = *replacement;
}
newOperands.push_back(buffer);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..522cb2c4537ce2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
// -----
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+ return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+ %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+ %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+ return %1 : tensor<3x8xf16>
+}
+
+// -----
+
// CHECK-LABEL: func private @private_func
// CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
func.func private @private_func(tensor<?xf32>) -> (f32)
    auto memrefDestType = dyn_cast<MemRefType>(memRefType);
    assert(memrefDestType &&
           "buffer layout not supported on unranked tensors");
    BufferizationOptions options;
Review comment: Use the options that are passed to this function.
Review comment: Is it correct to set bufferAlignment to 0?
Review comment: No, just take it directly from the given options.
Author reply: Oh, I see.
Handle caller/callee type mismatch using castOrReallocMemRefValue instead of just a CastOp. The method inserts a reallocation + copy if it cannot be statically guaranteed that a direct cast would be valid. Fixes #105916.

[mlir][func][bufferization] Fix cast incompatible when bufferize callOp llvm/llvm-project#105929 llvm/llvm-project@7f04a8a