Skip to content

Commit e8fc282

Browse files
authored
[mlir][sparse] avoid non-perm on sparse tensor convert for new (llvm#72459)
This avoids seeing non-permutations on the convert from COO to non-COO for higher-dimensional new operators (viz. reading in BSR). This is step 1 out of 3 to make sparse_tensor.new work for BSR.
1 parent 8404406 commit e8fc282

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11891189
LogicalResult matchAndRewrite(NewOp op,
11901190
PatternRewriter &rewriter) const override {
11911191
Location loc = op.getLoc();
1192-
const auto dstTp = getSparseTensorType(op.getResult());
1193-
const auto encDst = dstTp.getEncoding();
1194-
if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
1192+
auto stt = getSparseTensorType(op.getResult());
1193+
auto enc = stt.getEncoding();
1194+
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
11951195
return failure();
11961196

11971197
// Implement the NewOp as follows:
11981198
// %orderedCoo = sparse_tensor.new %filename
11991199
// %t = sparse_tensor.convert %orderedCoo
1200+
// with enveloping reinterpret_map ops for non-permutations.
1201+
RankedTensorType dstTp = stt.getRankedTensorType();
12001202
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
12011203
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1202-
Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
1203-
op, dstTp.getRankedTensorType(), cooTensor);
1204+
Value convert = cooTensor;
1205+
if (!stt.isPermutation()) { // demap coo, demap dstTp
1206+
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1207+
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1208+
dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1209+
}
1210+
convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1211+
if (!stt.isPermutation()) // remap to original enc
1212+
convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1213+
rewriter.replaceOp(op, convert);
12041214

1205-
// Release the ordered COO tensor.
1215+
// Release the temporary ordered COO tensor.
12061216
rewriter.setInsertionPointAfterValue(convert);
12071217
rewriter.create<DeallocTensorOp>(loc, cooTensor);
12081218

12091219
return success();
12101220
}
12111221
};
12121222

1223+
/// Sparse rewriting rule for the out operator.
12131224
struct OutRewriter : public OpRewritePattern<OutOp> {
12141225
using OpRewritePattern::OpRewritePattern;
12151226
LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
12501261
primaryTypeFunctionSuffix(eltTp)};
12511262
Value value = genAllocaScalar(rewriter, loc, eltTp);
12521263
ModuleOp module = op->getParentOfType<ModuleOp>();
1264+
12531265
// For each element in the source tensor, output the element.
12541266
rewriter.create<ForeachOp>(
12551267
loc, src, std::nullopt,

0 commit comments

Comments
 (0)