
[mlir][sparse] avoid non-perm on sparse tensor convert for new #72459


Merged
aartbik merged 1 commit into llvm:main from bik on Nov 16, 2023

Conversation

aartbik (Contributor) commented Nov 16, 2023

This avoids seeing a non-permutation (dim-to-lvl) mapping on the convert from COO to non-COO for higher-dimensional new operators (viz. reading in BSR).

This is step 1 out of 3 to make sparse_tensor.new work for BSR
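
For context, a minimal sketch of the kind of target this is about, assuming a 2x2 block-sparse (BSR) encoding written in the sparse_tensor dialect's map syntax; the block size and shapes are illustrative, not taken from this patch:

#BSR = #sparse_tensor.encoding<{
  map = (i, j) -> (i floordiv 2 : dense,
                   j floordiv 2 : compressed,
                   i mod 2      : dense,
                   j mod 2      : dense)
}>
// The (floordiv, mod) dim-to-lvl map above is not a permutation of the
// dimensions; that is the "non-perm" case hit when reading such a tensor
// from file, e.g. (with !Filename standing in for the runtime file-name
// pointer type):
//   %bsr = sparse_tensor.new %f : !Filename to tensor<?x?xf64, #BSR>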

llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels Nov 16, 2023
llvmbot (Member) commented Nov 16, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes

This avoids seeing non-perm on the convert from COO to non-COO for higher dimensional new operators (viz. reading in BSR).

This is step 1 out of 3 to make sparse_tensor.new work for BSR


Full diff: https://github.com/llvm/llvm-project/pull/72459.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+18-6)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 811bdc57ce14fb6..3fe0c551be57a4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   LogicalResult matchAndRewrite(NewOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    const auto dstTp = getSparseTensorType(op.getResult());
-    const auto encDst = dstTp.getEncoding();
-    if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
+    auto stt = getSparseTensorType(op.getResult());
+    auto enc = stt.getEncoding();
+    if (!stt.hasEncoding() || getCOOStart(enc) == 0)
       return failure();
 
     // Implement the NewOp as follows:
     //   %orderedCoo = sparse_tensor.new %filename
     //   %t = sparse_tensor.convert %orderedCoo
+    // with enveloping reinterpreted_map ops for non-permutations.
+    RankedTensorType dstTp = stt.getRankedTensorType();
     RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
-    Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
-        op, dstTp.getRankedTensorType(), cooTensor);
+    Value convert = cooTensor;
+    if (!stt.isPermutation()) { // demap coo, demap dstTp
+      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
+      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
+      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
+    }
+    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
+    if (!stt.isPermutation()) // remap to original enc
+      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
+    rewriter.replaceOp(op, convert);
 
-    // Release the ordered COO tensor.
+    // Release the temporary ordered COO tensor.
     rewriter.setInsertionPointAfterValue(convert);
     rewriter.create<DeallocTensorOp>(loc, cooTensor);
 
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   }
 };
 
+/// Sparse rewriting rule for the out operator.
 struct OutRewriter : public OpRewritePattern<OutOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
                                     primaryTypeFunctionSuffix(eltTp)};
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     ModuleOp module = op->getParentOfType<ModuleOp>();
+
     // For each element in the source tensor, output the element.
     rewriter.create<ForeachOp>(
         loc, src, std::nullopt,

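For readers skimming the diff, the envelope that the updated NewRewriter builds can be summarized as follows. This is only a hand-written sketch of the op sequence; the exact COO level types and demapped tensor types are omitted, and the SSA names are made up:

// Before: a single new op with a non-permutation target encoding.
//   %bsr  = sparse_tensor.new %f : !Filename to tensor<?x?xf64, #BSR>
// After (schematic):
//   %coo  = sparse_tensor.new %f                  // read into ordered COO
//   %dcoo = sparse_tensor.reinterpret_map %coo    // demap to identity dim-to-lvl
//   %dbsr = sparse_tensor.convert %dcoo           // convert sees only permutations
//   %bsr  = sparse_tensor.reinterpret_map %dbsr   // remap to the original encoding
//   sparse_tensor.dealloc_tensor %coo             // release the temporary COO
// For permutation encodings the two reinterpret_map ops are simply not emitted.
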
aartbik merged commit e8fc282 into llvm:main Nov 16, 2023
aartbik deleted the bik branch November 16, 2023 04:47
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023