From 548a4eab0b2e1d9c22567023b29e212d83198714 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Fri, 2 Feb 2024 16:51:38 -0800
Subject: [PATCH 1/2] [mlir][sparse] refine sparse assembler strategy

Rewrite *all* public methods, making the originals internal, private
methods, and exposing wrappers under the original names. This works a
bit better in practice (when combined with the c-interface mechanism of
torch-mlir, for example).
---
 .../Dialect/SparseTensor/Transforms/Passes.td |  2 +-
 .../Transforms/SparseAssembler.cpp            | 55 ++++++++++---------
 mlir/test/Dialect/SparseTensor/external.mlir  | 49 +++++++++--------
 .../Dialect/SparseTensor/torch_linalg.mlir    | 55 +++++++++++++++++++
 4 files changed, 113 insertions(+), 48 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/torch_linalg.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 8772d5f127949..58e2d6f32386c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   let summary = "Add [dis]assemble operations on external sparse tensors";
   let description = [{
     A pass that converts public entry methods that use sparse tensors as
-    input parameters and/or output return values into wrapper functions
+    input parameters and/or output return values into wrapper methods
     that [dis]assemble the individual tensors that constitute the actual
     storage used externally into MLIR sparse tensors. This pass can be used
     to prepare the public entry methods of a program that is compiled by the
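Sketch of the net effect of the new strategy on a sparse input, modeled
on the @sparse_in test in external.mlir below. This is a hand-written
sketch, not verbatim pass output: the buffer order (values first, then
positions/coordinates) and the exact sparse_tensor.assemble type syntax
are reconstructed from the test's CHECK lines and may print differently
across MLIR versions.

```mlir
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Public wrapper under the original name: assembles the constituent
// arrays into a sparse tensor, then calls the hidden original method.
func.func @sparse_in(%val: tensor<?xf32>, %pos: tensor<?xindex>,
                     %crd: tensor<?xindex>) -> tensor<64x64xf32> {
  %t = sparse_tensor.assemble %val, %pos, %crd
     : tensor<?xf32>, tensor<?xindex>, tensor<?xindex>
     to tensor<64x64xf32, #sparse>
  %r = call @_internal_sparse_in(%t)
     : (tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
  return %r : tensor<64x64xf32>
}

// The original method, renamed with the _internal_ prefix and made private.
func.func private @_internal_sparse_in(%arg0: tensor<64x64xf32, #sparse>)
    -> tensor<64x64xf32> {
  %0 = sparse_tensor.convert %arg0
     : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}
```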
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index f9b6397e0f086..b4cefec8fb21f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
 namespace {
 
 // A rewriting rule that converts public entry methods that use sparse tensors
-// as input parameters and/or output return values into wrapper functions
-// that [dis]assemble the individual tensors that constitute the actual
-// storage used externally into MLIR sparse tensors.
+// as input parameters and/or output return values into wrapper methods that
+// [dis]assemble the individual tensors that constitute the actual storage used
+// externally into MLIR sparse tensors before calling the origal method.
 //
 // In particular, each sparse tensor input
 //
 // void foo(..., t, ...) { }
 //
-// adds the following strucuture in a wrapper
+// makes the original foo() internal and adds the following wrapper method
 //
-// void spiface_foo(..., t1..tn, ...) {
+// void foo(..., t1..tn, ...) {
 //   t = assemble t1..tn
-//   foo(..., t, ...)
+//   _internal_foo(..., t, ...)
 // }
 //
 // and likewise, each output tensor
 //
 // ... T ... bar(...) { return ..., t, ...; }
 //
-// adds the following structure in a wrapper
+// makes the original bar() internal and adds the following wrapper method
 //
-// ... T1..TN ... spiface_bar(..., t1'..tn') {
-//   ..., t, ... = bar(...)
+// ... T1..TN ... bar(..., t1'..tn') {
+//   ..., t, ... = _internal_bar(...)
 //   t1..tn = disassemble t, t1'..tn'
 //   return ..., t1..tn, ...
 // }
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
-    // Only a rewrite an entry with the c-interface requested.
-    if (!funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName()))
+    // Only rewrite public entry methods.
+    if (funcOp.isPrivate())
       return failure();
 
     // Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convTypes(funcOp.getArgumentTypes(), inputTypes);
     convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
 
-    // Only sparse inputs or outputs need a wrapper function.
+    // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
         outputTypes.size() == funcOp.getResultTypes().size())
       return failure();
 
-    // Start the new wrapper function. Together with the c-interface mangling,
-    // a sparse external entry point eventually will have a name like:
-    //   _mlir_ciface_spiface_XXX(...)
+    // Modify the original method into an internal, private method.
+    auto orgName = funcOp.getName();
+    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
+    funcOp.setName(wrapper);
+    funcOp.setPrivate();
+
+    // Start the new public wrapper method with original name.
     Location loc = funcOp.getLoc();
     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
     MLIRContext *context = modOp.getContext();
     OpBuilder moduleBuilder(modOp.getBodyRegion());
-    std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
     unsigned extra = inputTypes.size();
     inputTypes.append(extraTypes);
     auto func = moduleBuilder.create<func::FuncOp>(
-        loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
+        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
     func.setPublic();
-    func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
-                  UnitAttr::get(context));
 
-    // Construct new wrapper function body.
-    auto org = SymbolRefAttr::get(context, funcOp.getName());
+    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
     Block *body = func.addEntryBlock();
     rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original function.
+    // Call original, now internal method.
+    auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(),
                                               org, inputs);
 
@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
              body->getArguments(), outputs, extra, /*isIn=*/false);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
-    // Strip the c-interface attribute from the original function.
-    funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    // Finally, migrate a potential c-interface property.
+    if (funcOp->getAttrOfType<UnitAttr>(
+            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(context));
+      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    }
     return success();
   }
 };
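The migration step at the end of matchAndRewrite is what keeps this
compatible with the LLVM c-interface machinery: when the original entry
carried llvm.emit_c_interface, the attribute now travels to the public
wrapper, so later lowering emits _mlir_ciface_XXX for the wrapper rather
than for the hidden internal method. A minimal input exercising that
path is sketched below (the resulting call structure is what the new
torch_linalg.mlir test checks at the CHECK-LOW level):

```mlir
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Public entry requesting a C wrapper. After --sparse-assembler, @main is
// the assembling wrapper and keeps llvm.emit_c_interface, while the body
// moves to a private @_internal_main without the attribute.
func.func @main(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
    attributes {llvm.emit_c_interface} {
  %0 = sparse_tensor.convert %arg0
     : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}
```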
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index 57df8aca3a6a5..c17ba13e86c92 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -3,95 +3,100 @@
 // -----
 
 // CHECK-LABEL: func.func @nop(
-// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
 // CHECK: return %[[A]] : tensor<100xf32>
 // CHECK: }
-func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
+func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
   return %arg0 : tensor<100xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in(
+// CHECK-LABEL: func.func @sparse_in(
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_in
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in2(
+// CHECK-LABEL: func.func @sparse_in2(
 // CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_in2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out(
+// CHECK-LABEL: func.func @sparse_out(
 // CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]] = call @sparse_out(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_out
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %0 : tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out2(
+// CHECK-LABEL: func.func @sparse_out2(
 // CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]]:2 = call @sparse_out2(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]#1
 // CHECK: return %[[F]]#0
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_out2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_inout(
+// CHECK-LABEL: func.func @sparse_inout(
 // CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
 // CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
 // CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
 // CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
 // CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
 // CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_inout(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return
 // CHECK: }
+// CHECK: func.func private @_internal_sparse_inout
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
new file mode 100644
index 0000000000000..f29e6b143783a
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --linalg-generalize-named-ops \
+// RUN:             --linalg-fuse-elementwise-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --sparsifier | FileCheck %s --check-prefix=CHECK-LOW
+
+//
+// An example of a module generated by torch-mlir with a sparse tensor from
+// torch.sparse. The MLIR sparsifier should be able to provide the external
+// API through a wrapper method (spiface and ciface). Various passes should
+// compose without trouble.
+//
+
+// CHECK-HI-LABEL: func.func @main
+// CHECK-HI:         sparse_tensor.assemble
+// CHECK-HI:         call @_internal_main
+// CHECK-HI:         return
+// CHECK-HI:       func.func private @_internal_main
+// CHECK-HI:         linalg.matmul
+// CHECK-HI:         return
+//
+// CHECK-MID-LABEL: func.func @main
+// CHECK-MID:         memref.load
+// CHECK-MID:         call @_internal_main
+// CHECK-MID:         return
+// CHECK-MID:       func.func private @_internal_main
+// CHECK-MID:         scf.for
+// CHECK-MID:           scf.for
+// CHECK-MID:         return
+
+// CHECK-LOW-LABEL: llvm.func @main
+// CHECK-LOW:         llvm.call @_internal_main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_mlir_ciface_main
+// CHECK-LOW:         llvm.call @main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_internal_main
+// CHECK-SAME:        {sym_visibility = "private"}
+// CHECK-LOW:         llvm.return
+
+#csc = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+module {
+  func.func @main(%arg0: tensor<64x64xf32, #csc>,
+                  %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = linalg.matmul
+      ins(%arg0, %arg1 : tensor<64x64xf32, #csc>, tensor<64x64xf32>)
+      outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
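For reference, the CHECK-HI run line above corresponds roughly to the
wrapper shape sketched here. This is hand-written under the same
assumptions as before (values-then-levels buffer order and the flat
assemble syntax reconstructed from external.mlir), not captured pass
output:

```mlir
#csc = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Public wrapper synthesized for @main: the sparse operand %arg0 becomes
// three array arguments, the dense operand passes through unchanged, and
// the c-interface attribute stays on the public method.
func.func @main(%val: tensor<?xf32>, %pos: tensor<?xindex>,
                %crd: tensor<?xindex>, %dense: tensor<64x64xf32>)
    -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
  %t = sparse_tensor.assemble %val, %pos, %crd
     : tensor<?xf32>, tensor<?xindex>, tensor<?xindex>
     to tensor<64x64xf32, #csc>
  %r = call @_internal_main(%t, %dense)
     : (tensor<64x64xf32, #csc>, tensor<64x64xf32>) -> tensor<64x64xf32>
  return %r : tensor<64x64xf32>
}
```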
From 2745fd76e353f82cb506298c8bde5e2bd221b826 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Mon, 5 Feb 2024 10:11:13 -0800
Subject: [PATCH 2/2] typo

---
 mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b4cefec8fb21f..98f9d15d09fa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -134,7 +134,7 @@ namespace {
 // A rewriting rule that converts public entry methods that use sparse tensors
 // as input parameters and/or output return values into wrapper methods that
 // [dis]assemble the individual tensors that constitute the actual storage used
-// externally into MLIR sparse tensors before calling the origal method.
+// externally into MLIR sparse tensors before calling the original method.
 //
 // In particular, each sparse tensor input
 //