diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 212f7b6f13c26..af64370a62dd7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -55,6 +55,16 @@ enum class SparseEmitStrategy {
   kDebugInterface, // generate only place-holder for sparse iteration
 };
 
+namespace sparse_tensor {
+
+/// Defines a strategy for loop ordering during sparse code generation.
+enum class LoopOrderingStrategy : unsigned {
+  kDefault, ///< Default strategy (eagerly selects last loop in topological
+            ///< sort).
+};
+
+} // namespace sparse_tensor
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -71,11 +81,16 @@ std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                  ReinterpretMapScope scope);
+void populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy =
+        sparse_tensor::LoopOrderingStrategy::kDefault);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
 std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
+std::unique_ptr<Pass>
+createSparseReinterpretMapPass(ReinterpretMapScope scope,
+                               sparse_tensor::LoopOrderingStrategy strategy);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2513e106f5b06..75e77d67db1b3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -81,6 +81,11 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
              clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
                         "except-generic",
                         "Run on operations expect linalg.generic (e.g., foreach)"))}]>,
+    Option<"loopOrderingStrategy", "loop-ordering-strategy", "mlir::sparse_tensor::LoopOrderingStrategy",
+           "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
+           "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
+             clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
+                        "Default strategy (eagerly selects last loop in topological sort)"))}]>,
   ];
 }
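Usage sketch (not part of the patch): the strategy parameter on populateSparseReinterpretMap defaults to kDefault, so existing callers remain source-compatible; a client opts in by passing the third argument explicitly. The surrounding function below is hypothetical, added for illustration only.

    // Hypothetical caller; only the explicit third argument is new.
    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

    using namespace mlir;

    static void addReinterpretMapPatterns(RewritePatternSet &patterns) {
      populateSparseReinterpretMap(
          patterns, ReinterpretMapScope::kGenericOnly,
          sparse_tensor::LoopOrderingStrategy::kDefault);
    }

From the command line, the Passes.td entry above exposes the same knob as --sparse-reinterpret-map="loop-ordering-strategy=default".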
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a1e35b87399ca..0fc5cc76de39c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
 
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
-  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
   void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
   BitVector dims;
 };
@@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineExprAdmissibleVisitor
     : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
-  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
+  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
 
   // We only allow AffineDimExpr on output.
   void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,7 +407,10 @@ struct GenericOpReinterpretMap
 };
 
 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern::OpRewritePattern;
+  GenericOpScheduler(MLIRContext *context,
+                     sparse_tensor::LoopOrderingStrategy strategy)
+      : OpRewritePattern(context), strategy(strategy) {}
+
   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     if (linalgOp->hasAttr(sorted))
       return failure();
 
-    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
+    // Pass strategy to IterationGraphSorter.
+    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
     bool isAdmissible = false;
     AffineMap order;
     // A const list of all masks that we used for iteration graph
@@ -582,6 +586,9 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
       // TODO: convert more than one?
       return failure();
     }
+
+private:
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 //===----------------------------------------------------------------------===//
@@ -786,12 +793,13 @@ struct ForeachOpDemapper
 }
 
 } // namespace
 
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                        ReinterpretMapScope scope) {
+void mlir::populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy) {
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
-    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
-        patterns.getContext());
+    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+    patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 153b9b170e5d3..b660e22154688 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
   SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
     scope = options.scope;
+    loopOrderingStrategy = options.loopOrderingStrategy;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns, scope);
+    populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
   return std::make_unique<SparseReinterpretMap>(options);
 }
 
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
+    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
+  SparseReinterpretMapOptions options;
+  options.scope = scope;
+  options.loopOrderingStrategy = strategy;
+  return std::make_unique<SparseReinterpretMap>(options);
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }
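Pipeline sketch (not part of the patch): the new createSparseReinterpretMapPass overload simply forwards both values through SparseReinterpretMapOptions, so a C++-built pipeline can select the strategy without the textual pass syntax. The surrounding function is hypothetical.

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    using namespace mlir;

    static void buildSparsePipeline(PassManager &pm) {
      // Module-level pass; same effect as setting
      // loop-ordering-strategy=default on --sparse-reinterpret-map.
      pm.addPass(createSparseReinterpretMapPass(
          ReinterpretMapScope::kAll,
          sparse_tensor::LoopOrderingStrategy::kDefault));
    }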
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c7e463a5a5b49..73e0f3d2891d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
     // We always prefer a parallel loop over a reduction loop because putting
     // a reduction loop early might make the loop sequence inadmissible.
     auto &it = !parIt.empty() ? parIt : redIt;
-    auto src = it.back();
+
+    // Select loop based on strategy.
+    unsigned src;
+    switch (strategy) {
+    case sparse_tensor::LoopOrderingStrategy::kDefault:
+      src = it.back();
+      break;
+    }
+
     loopOrder.push_back(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
@@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() {
   return AffineMap();
 }
 
-IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter IterationGraphSorter::fromGenericOp(
+    linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
   // Must be a demapped sparse kernel.
   assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
          hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
       genericOp.getIteratorTypesArray();
 
   return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
-                              std::move(iterTypes));
+                              std::move(iterTypes), strategy);
 }
 
 IterationGraphSorter::IterationGraphSorter(
     SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
-    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+    sparse_tensor::LoopOrderingStrategy strategy)
     : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
-      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+      strategy(strategy) {
   // One map per tensor.
   assert(loop2InsLvl.size() == ins.size());
   // All the affine maps have the same number of dimensions (loops).
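The switch in topoSort is the single extension point: every strategy sees the same ready worklist (with parallel loops preferred over reductions, per the comment above) and only decides which ready loop is emitted next. A second strategy, sketched below purely as a hypothetical (kFirstReady does not exist in this patch), would add one case; the element removal after the switch must stay consistent with the choice.

    switch (strategy) {
    case sparse_tensor::LoopOrderingStrategy::kDefault:
      src = it.back(); // eagerly take the most recently readied loop
      break;
    // Hypothetical alternative, not in this patch:
    // case sparse_tensor::LoopOrderingStrategy::kFirstReady:
    //   src = it.front(); // would also need to erase the front, not pop_back()
    //   break;
    }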
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a6abe9eb76c47..b2a16e9382758 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/IR/AffineMap.h"
 
 namespace mlir {
@@ -41,9 +42,12 @@ enum class SortMask : unsigned {
 
 class IterationGraphSorter {
 public:
-  /// Factory method that construct an iteration graph sorter
-  /// for the given linalg.generic operation.
-  static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
+  /// Factory method that constructs an iteration graph sorter
+  /// for the given linalg.generic operation with a specific loop ordering
+  /// strategy.
+  static IterationGraphSorter
+  fromGenericOp(linalg::GenericOp genericOp,
+                sparse_tensor::LoopOrderingStrategy strategy);
 
   /// Returns a permutation that represents the scheduled loop order.
   /// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +62,9 @@ class IterationGraphSorter {
   IterationGraphSorter(SmallVector<Value> &&ins,
                        SmallVector<AffineMap> &&loop2InsLvl, Value out,
                        AffineMap loop2OutLvl,
-                       SmallVector<utils::IteratorType> &&iterTypes);
+                       SmallVector<utils::IteratorType> &&iterTypes,
+                       sparse_tensor::LoopOrderingStrategy strategy =
+                           sparse_tensor::LoopOrderingStrategy::kDefault);
 
   // Adds all the constraints in the given loop to level map.
   void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +90,9 @@ class IterationGraphSorter {
 
   // InDegree used for topo sort.
   std::vector<unsigned> inDegree;
+
+  // Loop ordering strategy.
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 } // namespace sparse_tensor
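End-to-end, a caller drives the sorter the way GenericOpScheduler does above. A sketch, assuming the sort entry point and SortMask spelling from the surrounding header:

    // Schedule loops for a demapped sparse kernel with the default strategy.
    auto sorter = IterationGraphSorter::fromGenericOp(
        linalgOp, sparse_tensor::LoopOrderingStrategy::kDefault);
    AffineMap order = sorter.sort(SortMask::kIncludeAll);
    if (!order)
      return failure(); // cyclic iteration graph; retry with a weaker mask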