diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index c1e217675020f..c537e92a51d53 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -87,6 +87,15 @@ std::unique_ptr createSparsificationPass(); std::unique_ptr createSparsificationPass(const SparsificationOptions &options); +//===----------------------------------------------------------------------===// +// The StageSparseOperations pass. +//===----------------------------------------------------------------------===// + +/// Sets up StageSparseOperation rewriting rules. +void populateStageSparseOperationsPatterns(RewritePatternSet &patterns); + +std::unique_ptr createStageSparseOperationsPass(); + //===----------------------------------------------------------------------===// // The PostSparsificationRewriting pass. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index d8d5dbb5ad3ce..8f116bff9b185 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { ]; } +def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> { + let summary = "Decompose a complex sparse operation into multiple stages"; + let description = [{ + A pass that decomposes a complex sparse operation into multiple stages. + E.g., CSR -> CSC is staged into CSR -> COO (unordered) -> sort -> CSC. + }]; + let constructor = "mlir::createStageSparseOperationsPass()"; + let dependentDialects = [ + "sparse_tensor::SparseTensorDialect", + ]; +} + def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> { let summary = "Applies sparse tensor rewriting rules after sparsification"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt index 5ef9d906f0e8b..0ca6668c8c747 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms SparseVectorization.cpp Sparsification.cpp SparsificationAndBufferizationPass.cpp + StageSparseOperations.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index f50d3d4606554..e1f88ad9c0e11 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -30,6 +30,7 @@ namespace mlir { #define GEN_PASS_DEF_SPARSEBUFFERREWRITE #define GEN_PASS_DEF_SPARSEVECTORIZATION #define GEN_PASS_DEF_SPARSEGPUCODEGEN +#define GEN_PASS_DEF_STAGESPARSEOPERATIONS #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" } // namespace mlir @@ -92,6 +93,18 @@ struct SparsificationPass } }; +struct StageSparseOperationsPass + : public impl::StageSparseOperationsBase { + StageSparseOperationsPass() = default; + StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default; + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateStageSparseOperationsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct PostSparsificationRewritePass : public impl::PostSparsificationRewriteBase< PostSparsificationRewritePass> { @@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) { return std::make_unique(options); } +std::unique_ptr mlir::createStageSparseOperationsPass() { + return std::make_unique(); +} + std::unique_ptr mlir::createPostSparsificationRewritePass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp new file mode 100644 index 0000000000000..4adc4d131198c --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp @@ -0,0 +1,4 @@ +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" + +void mlir::populateStageSparseOperationsPatterns( + RewritePatternSet & /*patterns*/) {}