From 837d02f1123649006ea6935671da5f8dbdb97560 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 14 Jun 2024 16:56:47 +0000 Subject: [PATCH 1/2] [mlir][sparse] add conanicalization patterns for IterateOp. --- .../SparseTensor/IR/SparseTensorOps.td | 8 +++++ .../SparseTensor/IR/SparseTensorDialect.cpp | 34 +++++++++++++++++++ .../Dialect/SparseTensor/canonicalize.mlir | 18 ++++++++++ 3 files changed, 60 insertions(+) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 5ae6f9f3443f8..a20de92d2d3ed 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate", BlockArgument getIterator() { return getRegion().getArguments().front(); } + std::optional getLvlCrd(Level lvl) { + if (getCrdUsedLvls()[lvl]) { + uint64_t mask = (1 << lvl) - 1; + return getCrds()[llvm::popcount(mask & getCrdUsedLvls())]; + } + return std::nullopt; + } Block::BlockArgListType getCrds() { // The first block argument is iterator, the remaining arguments are // referenced coordinates. @@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate", let hasVerifier = 1; let hasRegionVerifier = 1; + let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 232d25d718c65..ac711769ed2ea 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/Bitset.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" @@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() { return success(); } +struct RemoveUnusedLvlCrds : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IterateOp iterateOp, + PatternRewriter &rewriter) const override { + LevelSet newUsedLvls(0); + llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); + for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { + if (auto crd = iterateOp.getLvlCrd(i)) { + if (crd->getUsers().empty()) + toRemove.set(crd->getArgNumber()); + else + newUsedLvls.set(i); + } + } + + // All coordinates are used. + if (toRemove.none()) + return failure(); + + rewriter.startOpModification(iterateOp); + iterateOp.setCrdUsedLvls(newUsedLvls); + iterateOp.getBody()->eraseArguments(toRemove); + rewriter.finalizeOpModification(iterateOp); + return success(); + } +}; + +void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, + mlir::MLIRContext *context) { + results.add(context); +} + ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::Argument iterator; OpAsmParser::UnresolvedOperand iterSpace; diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir index b1d3d7916c142..ceb82cab516ed 100644 --- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir +++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir @@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor, %arg1 : i %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor to tensor return %0 : tensor } + +// ----- + +#CSR = #sparse_tensor.encoding<{ + map = (i, j) -> (i : dense, j : compressed) +}> + +// Make sure that the first unused coordinate is optimized. +// CHECK-LABEL: @sparse_iterate_canonicalize +// CHECK: sparse_tensor.iterate {{.*}} at(_, %{{.*}}) +func.func @sparse_iterate_canonicalize(%sp : tensor) { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2 + : tensor -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> + sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> { + "test.op"(%coord1) : (index) -> () + } + return +} From c1e4fb762c0d76741016d00c2cdffd8b56ad1175 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 14 Jun 2024 17:14:11 +0000 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a20de92d2d3ed..b2089924291cd 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1603,7 +1603,7 @@ def IterateOp : SparseTensor_Op<"iterate", } std::optional getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { - uint64_t mask = (1 << lvl) - 1; + uint64_t mask = (static_cast(0x01u) << lvl) - 1; return getCrds()[llvm::popcount(mask & getCrdUsedLvls())]; } return std::nullopt;