diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3043a0c4dc410 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..c6554e1c94a4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "sparse space collapsing pass";
+  let description = [{
+    This pass collapses consecutive sparse spaces (extracted from the same
+    tensor) into one multi-dimensional space. The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..924046fcd9961
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,199 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "sparse-space-collapse"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+namespace {
+
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsable if they are perfectly nested.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getExtractedSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse spaces extracted from different tensors.";
+    });
+    return false;
+  }
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation()) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
+    return false;
+  }
+
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse IterateOps that are not perfectly nested.";
+    });
+    return false;
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(root);
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%sp1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // if not legal to collapse one more space, collapse the existing ones
+        // and clear.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+  return std::make_unique<SparseSpaceCollapsePass>();
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..baa6199f12bc3
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO>
+     -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k ="test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}
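For reference, the collapsed IR described by the CHECK lines in the test above looks roughly like the sketch below. This is hand-written for illustration only; the SSA value names and the printed form of the collapsed iteration-space type are assumptions, not actual pass output.

  func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
    // The two single-level spaces extracted from %sp are merged into one
    // space covering lvls = 0 to 2.
    %space = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
        : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
    // The two perfectly nested iterate loops collapse into a single loop; the
    // second coordinate is unused in the body, hence the `_` in at(...).
    %r = sparse_tensor.iterate %it in %space at(%crd0, _) iter_args(%acc = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 2> -> index {
      %k = "test.op"(%acc) : (index) -> index
      sparse_tensor.yield %k : index
    }
    "test.sink"(%r) : (index) -> ()
    return
  }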