From c56970af6ebb9f7dcc242166ba3591c26e07f988 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Mon, 27 Nov 2023 22:22:47 +0000 Subject: [PATCH 1/4] [mlir][sparse] code cleanup, remove FIXMEs --- .../Dialect/SparseTensor/IR/SparseTensor.h | 16 +++++---- .../SparseTensor/IR/SparseTensorDialect.cpp | 36 ++++++------------- .../SparseTensor/Transforms/CodegenUtils.cpp | 26 -------------- .../SparseTensor/Transforms/CodegenUtils.h | 7 ---- .../SparseTensor/Transforms/LoopEmitter.cpp | 6 ++-- .../Transforms/SparseTensorRewriting.cpp | 9 ++--- 6 files changed, 25 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index eb7c50ae2efdf..f102f02701542 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl); // Reordering. // -/// [deprecated] Convenience method to translate the given level to the -/// corresponding dimension. Requires: `0 <= l < lvlRank`. -Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l); - -/// [deprecated] Convenience method to translate the given dimension to -/// the corresponding level. Requires: `0 <= d < dimRank`. -Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d); +/// Convenience method to translate the given level to the corresponding +/// dimension. +/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`. +Dimension toDim(SparseTensorEncodingAttr enc, Level l); + +/// Convenience method to translate the given dimension to the corresponding +/// level. +/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`. +Level toLvl(SparseTensorEncodingAttr enc, Dimension d); } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 791aeebee5a32..28e07e1669e79 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { std::optional SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { - // FIXME: `toOrigDim` is deprecated. - return getStaticDimSliceOffset(toOrigDim(*this, lvl)); + return getStaticDimSliceOffset(toDim(*this, lvl)); } std::optional SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { - // FIXME: `toOrigDim` is deprecated. - return getStaticDimSliceStride(toOrigDim(*this, lvl)); + return getStaticDimSliceStride(toDim(*this, lvl)); } SmallVector @@ -399,9 +397,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef srcShape, if (isPermutation()) { for (unsigned r = 0; r < rank; r++) { // FIXME: `toOrigDim` and `toStoredDim` are deprecated. - unsigned trans = dir == CrdTransDirectionKind::dim2lvl - ? toOrigDim(*this, r) - : toStoredDim(*this, r); + unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) + : toLvl(*this, r); ret.push_back(srcShape[trans]); } return ret; @@ -925,31 +922,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src, ordered); } -// TODO: Remove this definition once all use-sites have been fixed to -// properly handle non-permutations. -Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc, - Level l) { +Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { if (enc) { - if (const auto dimToLvl = enc.getDimToLvl()) { - assert(enc.isPermutation()); + assert(enc.isPermutation() && "Non permutation map"); + if (const auto dimToLvl = enc.getDimToLvl()) return dimToLvl.getDimPosition(l); - } } return l; } -// TODO: Remove this definition once all use-sites have been fixed to -// properly handle non-permutations. -Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, - Dimension d) { +Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { if (enc) { - if (const auto dimToLvl = enc.getDimToLvl()) { - assert(enc.isPermutation()); - auto maybePos = - dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext())); - assert(maybePos.has_value()); - return *maybePos; - } + assert(enc.isPermutation() && ""); + if (const auto lvlToDim = enc.getLvlToDim()) + return lvlToDim.getDimPosition(d); } return d; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 1200b999f9a90..33d449aac5a35 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem, } } -Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, - ValueRange dimSizes, - Value valuesBuffer, - Value lvlCoords) { - // Reuse the `lvlCoords` buffer to store the level-sizes. - const Level lvlRank = enc.getLvlRank(); - SmallVector lvlSizes; - lvlSizes.reserve(lvlRank); - for (Level l = 0; l < lvlRank; l++) - // FIXME: `toOrigDim` is deprecated. - lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]); - storeAll(builder, loc, lvlCoords, lvlSizes); - // The memref ReshapeOp requires the sizes buffer to have a static - // shape. - const auto iTp = builder.getIndexType(); - const SmallVector lvlSizesShape{static_cast(lvlRank)}; - const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp); - lvlCoords = builder.create(loc, lvlSizesTp, lvlCoords); - // Finally, create the ReshapeOp. - const SmallVector resShape(lvlRank, ShapedType::kDynamic); - const Type elemTp = getMemRefType(valuesBuffer).getElementType(); - const auto resTp = MemRefType::get(resShape, elemTp); - return builder.create(loc, resTp, valuesBuffer, lvlCoords); -} - TypedValue sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { auto tTp = llvm::cast(tensor.getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index cb0acdd2be9f7..0ce33427281f5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -277,13 +277,6 @@ SmallVector loadAll(OpBuilder &builder, Location loc, size_t size, void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx = 0, Value offsetVal = Value()); -/// Reshapes the linear values buffer for an annotated all dense sparse tensor -/// to match the shape of the corresponding dense tensor to support direct -/// access of the buffer through `lvlCoords`. -Value reshapeValuesToLevels(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ValueRange dimSizes, - Value valuesBuffer, Value lvlCoords); - // Generates code to cast a tensor to a memref. TypedValue genToMemref(OpBuilder &builder, Location loc, Value tensor); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index f8bcc0fe12a10..413a835ff14d3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3; static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor, Level lvl) { auto enc = getSparseTensorEncoding(tensor.getType()); - // FIXME: `toOrigDim` is deprecated - return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl)); + return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl)); } static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor, Level lvl) { auto enc = getSparseTensorEncoding(tensor.getType()); - // FIXME: `toOrigDim` is deprecated - return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl)); + return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl)); } /// Converts a coordinate relative to the slice to the coordinate relative diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 5374ab55c5c0d..103908b2cf5bd 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern { SmallVector srcDcvs; srcDcvs.reserve(srcRank); for (Dimension d = 0; d < srcRank; d++) { - // FIXME: `toStoredDim` is deprecated - Level lvl = toStoredDim(encSrc, d); + Level lvl = toLvl(encSrc, d); srcDcvs.push_back(srcLcvs[lvl]); } @@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern { SmallVector srcDcvs; srcDcvs.reserve(dimRank); for (Dimension d = 0; d < dimRank; d++) { - // FIXME: `toStoredDim` is deprecated - Level lvl = toStoredDim(encSrc, d); + Level lvl = toLvl(encSrc, d); srcDcvs.push_back(srcLcvs[lvl]); } SmallVector dstDcvs; @@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern { return failure(); if (stt.isPermutation()) { - // FIXME: `toStoredDim` is deprecated rewriter.replaceOpWithNewOp(op, op.getSource(), - toStoredDim(stt.getEncoding(), *dim)); + toLvl(stt.getEncoding(), *dim)); return success(); } From 80ca0bb13c336c2ac76bd2d3a7236d005a05ffb5 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Mon, 27 Nov 2023 22:33:43 +0000 Subject: [PATCH 2/4] remove unused variables --- .../SparseTensor/Transforms/IterationGraphSorter.h | 1 - .../SparseTensor/Transforms/SparseReinterpretMap.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h index 613a8609ac097..52ee117029300 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h @@ -26,7 +26,6 @@ enum class SortMask : unsigned { // The individual mask bits. kIncludeDenseOutput = 0x1, // b001 kIncludeDenseInput = 0x2, // b010 - kIncludeUndef = 0x4, // b100 // The subsets of mask bits. kIncludeAll = 0x7, // b111 kIncludeDense = 0x3, // b011 diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index 268bd8fbe2738..c94ef8b962877 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern { // computation. Must be ordered from more strict to less strict. // Ideally (though might not be guaranteed), the earlier a constraint mask // can be satisfied, the faster the generated kernel will be. - const auto allMasks = { - SortMask::kIncludeAll, SortMask::kIncludeDense, - SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput, - SortMask::kIncludeUndef, SortMask::kSparseOnly}; + const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense, + SortMask::kIncludeDenseInput, + SortMask::kIncludeDenseOutput, + SortMask::kSparseOnly}; for (const SortMask mask : allMasks) { order = scheduler.sort(mask); if (order) { From c345ebd7dd5644a5d7c6cb5b95c9eb66b5001055 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Mon, 27 Nov 2023 22:37:33 +0000 Subject: [PATCH 3/4] small fix --- mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 28e07e1669e79..df37dfab9a2ef 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -924,7 +924,7 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src, Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { if (enc) { - assert(enc.isPermutation() && "Non permutation map"); + assert(enc.isPermutation() && "Non permutation map not supported"); if (const auto dimToLvl = enc.getDimToLvl()) return dimToLvl.getDimPosition(l); } @@ -933,7 +933,7 @@ Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { if (enc) { - assert(enc.isPermutation() && ""); + assert(enc.isPermutation() && "Non permutation map not supported"); if (const auto lvlToDim = enc.getLvlToDim()) return lvlToDim.getDimPosition(d); } From efc542a30384ccd608d30306014cb12ad59f1c36 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Mon, 27 Nov 2023 22:38:35 +0000 Subject: [PATCH 4/4] small fix --- mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index df37dfab9a2ef..fc897e7935510 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -396,7 +396,6 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef srcShape, if (isPermutation()) { for (unsigned r = 0; r < rank; r++) { - // FIXME: `toOrigDim` and `toStoredDim` are deprecated. unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) : toLvl(*this, r); ret.push_back(srcShape[trans]);