Skip to content

[mlir][func][bufferization] Fix cast incompatible when bufferize callOp #105929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 26, 2024

Conversation

CoTinker
Copy link
Contributor

Handle caller/callee type mismatch using castOrReallocMemRefValue instead of just a CastOp. The method insert a reallocation + copy if it cannot be statically guaranteed that a direct cast would be valid. Fix #105916.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Aug 24, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 24, 2024

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

Handle caller/callee type mismatch using castOrReallocMemRefValue instead of just a CastOp. The method insert a reallocation + copy if it cannot be statically guaranteed that a direct cast would be valid. Fix #105916.


Full diff: https://github.com/llvm/llvm-project/pull/105929.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+13-8)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+24)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..f85d0c35c5af33 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,25 @@ struct CallOpInterface
         return failure();
       Value buffer = *maybeBuffer;
 
-      // Caller / callee type mismatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
-      // something better.
+      // something better. Insert a reallocation + copy if it cannot be
+      // statically guaranteed that a direct cast would be valid.
       if (buffer.getType() != memRefType) {
-        assert(
-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
-            "CallOp::bufferize: cast incompatible");
-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
-                                                           memRefType, buffer);
-        buffer = castBuffer;
+        auto memrefDestType = dyn_cast<MemRefType>(memRefType);
+        assert(memrefDestType &&
+               "buffer layout not supported on unranked tensors");
+        BufferizationOptions options;
+        options.bufferAlignment = 0;
+        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+            rewriter, buffer, memrefDestType, options);
+        if (failed(replacement))
+          return failure();
+        buffer = *replacement;
       }
       newOperands.push_back(buffer);
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..522cb2c4537ce2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
 
 // -----
 
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME:                   %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+  return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME:                                  %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_2:.*]] = memref.alloc() : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+  return %1 : tensor<3x8xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func private @private_func
 // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
 func.func private @private_func(tensor<?xf32>) -> (f32)

auto memrefDestType = dyn_cast<MemRefType>(memRefType);
assert(memrefDestType &&
"buffer layout not supported on unranked tensors");
BufferizationOptions options;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use options that are passed to this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct to set bufferAlignment to 0?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, just take it directly from the given options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see

Handle caller/callee type mismatch using `castOrReallocMemRefValue` instead
of just a `CastOp`. The method insert a reallocation + copy if it cannot
be statically guaranteed that a direct cast would be valid.
@CoTinker CoTinker merged commit 7f04a8a into llvm:main Aug 26, 2024
8 checks passed
@CoTinker CoTinker deleted the fix_buffer branch August 27, 2024 04:12
paul0403 added a commit to PennyLaneAI/catalyst that referenced this pull request May 20, 2025
[mlir][func][bufferization] Fix cast incompatible when bufferize callOp llvm/llvm-project#105929
llvm/llvm-project@7f04a8a
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][bufferization] Assertion `memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && "CallOp::bufferize: cast incompatible"' failed
3 participants