Skip to content

[mlir][sparse] code cleanup, remove FIXMEs #73575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 27, 2023
Merged

Conversation

PeimingLiu
Copy link
Member

No description provided.

@PeimingLiu PeimingLiu requested a review from aartbik November 27, 2023 22:25
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Nov 27, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/73575.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+9-7)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+11-25)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (-26)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h (-7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+3-6)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a328d..28e07e1669e7919 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
 }
 
 SmallVector<int64_t>
@@ -399,9 +397,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
   if (isPermutation()) {
     for (unsigned r = 0; r < rank; r++) {
       // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
       ret.push_back(srcShape[trans]);
     }
     return ret;
@@ -925,31 +922,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
       ordered);
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   }
 }
 
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
 TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7b0..0ce33427281f598 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
               size_t offsetIdx = 0, Value offsetVal = Value());
 
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
 // Generates code to cast a tensor to a memref.
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a1093..413a835ff14d314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 /// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d9f..103908b2cf5bd89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
 
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
           SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
       return success();
     }
 

@PeimingLiu PeimingLiu merged commit 4e2f152 into llvm:main Nov 27, 2023
@PeimingLiu PeimingLiu deleted the cleanup branch November 27, 2023 22:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants