Commit 298412b

[mlir][sparse] setup SparseIterator to help generating code to traverse a sparse tensor level. (#78345)

1 parent 48bbd76 commit 298412b

19 files changed: +2266 -2654 lines
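For orientation: the commit's theme is a SparseIterator abstraction for walking one storage level of a sparse tensor, and the hunks below mostly migrate callers to narrower accessors (getValPosits, getCoord, locateLvlAtAffineAddress). The following is a minimal, self-contained sketch of what iterating a single compressed level involves; the class, the names (LevelIterator, locate, coord), and the layout are illustrative assumptions, not the interface added by this commit.

// Illustrative only: a hypothetical, self-contained model of iterating one
// compressed sparse level (CSR-style), independent of the MLIR classes.
#include <cstdint>
#include <cstdio>
#include <vector>

struct LevelIterator {                      // hypothetical name
  const std::vector<int64_t> &positions;    // the level's position buffer
  const std::vector<int64_t> &coordinates;  // the level's coordinate buffer
  int64_t pos, end;

  // Locate the segment of this level owned by parent position `p`.
  void locate(int64_t p) { pos = positions[p]; end = positions[p + 1]; }
  bool valid() const { return pos < end; }
  int64_t coord() const { return coordinates[pos]; }  // deref -> coordinate
  void next() { ++pos; }
};

int main() {
  // A 2x4 CSR matrix with nonzeros at (0,1), (0,3), (1,2).
  std::vector<int64_t> positions = {0, 2, 3};
  std::vector<int64_t> coordinates = {1, 3, 2};
  std::vector<double> values = {10.0, 20.0, 30.0};

  LevelIterator it{positions, coordinates, 0, 0};
  for (int64_t row = 0; row < 2; ++row)
    for (it.locate(row); it.valid(); it.next())
      std::printf("(%lld, %lld) = %g\n", (long long)row,
                  (long long)it.coord(), values[it.pos]);
}

The point of such an abstraction is that callers like the rewriters below can ask for "the coordinate or position at the current spot" without knowing whether the level is dense, compressed, or singleton.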

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 10 additions & 10 deletions
@@ -1126,7 +1126,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     }

     Value vals = loopEmitter.getValBuffer()[0];
-    Value pos = loopEmitter.getPosits()[0].back();
+    Value pos = loopEmitter.getValPosits(0);
     // Loads the value from sparse tensor using position-index;
     // loads the value from dense tensor using coords.
     Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
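This hunk swaps direct access to the emitter's internal position stack (getPosits()[0].back()) for a dedicated accessor (getValPosits(0)): for a sparse operand, only the last level's position is needed to index the flat values buffer, while a dense operand is indexed by its linearized coordinates. A hedged sketch of that addressing rule; the helper name (loadValue) is hypothetical and not part of the MLIR code:

// Illustrative only: position-indexed vs. coordinate-indexed value access.
#include <cstdint>

double loadValue(bool sparse, const double *vals, int64_t lastLvlPos,
                 const int64_t *coords, const int64_t *dimSizes, int rank) {
  if (sparse)
    return vals[lastLvlPos];  // sparse: flat values array, last-level position
  int64_t linear = 0;         // dense: row-major linearization of coords
  for (int d = 0; d < rank; ++d)
    linear = linear * dimSizes[d] + coords[d];
  return vals[linear];
}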
@@ -1148,17 +1148,17 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
     rewriter.eraseOp(srcBlock->getTerminator());

-    // Inline body.
-    if (!reducValue.empty()) {
-      rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
-    } else {
-      // This is annoying, since scf.for inserts an implicit yield op when
-      // there is no reduction variable upon creation; in this case we need
-      // to merge the block *before* the yield op.
-      rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
-                                 args);
+    Operation &last = rewriter.getBlock()->back();
+    if (llvm::isa<scf::YieldOp>(last)) {
+      // Because `scf.for` inserts an implicit yield op when there is no
+      // reduction variable upon creation, we reset the insertion point such
+      // that the block is inlined *before* the yield op.
+      rewriter.setInsertionPoint(&last);
     }

+    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
+                               rewriter.getInsertionPoint(), args);
+    rewriter.setInsertionPointToEnd(rewriter.getBlock());
     for (Level l = 0; l < lvlRank; l++) {
       // Link the reduction chain. Note that the loop emitter updates
       // reducValue in place.
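This hunk fixes where the foreach body is inlined: an scf.for built without iter_args already carries an implicit scf.yield terminator, so inlining at the end of the block would otherwise place ops after the terminator. The same pattern, framed as a standalone helper under the assumption that the MLIR headers below are available (the helper name is hypothetical):

// Sketch of the inline-before-implicit-terminator pattern from the hunk
// above, using only rewriter calls that appear in the diff.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Inline `src` into the rewriter's current block, making sure the inlined
// ops land before an implicit scf.yield terminator if one is present.
static void inlineBeforeImplicitYield(RewriterBase &rewriter, Block *src,
                                      ValueRange args) {
  Block *dest = rewriter.getBlock();
  Operation &last = dest->back();
  if (isa<scf::YieldOp>(last))
    rewriter.setInsertionPoint(&last); // land before the implicit yield
  rewriter.inlineBlockBefore(src, dest, rewriter.getInsertionPoint(), args);
  rewriter.setInsertionPointToEnd(dest); // keep emitting after the body
}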

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 18 additions & 10 deletions
@@ -354,7 +354,7 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
   const auto stt = getSparseTensorType(t->get());
   if (stt.hasEncoding()) {
     // For sparse tensors we only push the last-level's position onto `args`.
-    const auto pos = env.emitter().getPosits()[tid].back();
+    const auto pos = env.emitter().getValPosits(tid);
     assert(pos);
     args.push_back(pos);
   } else {
@@ -815,8 +815,7 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
-        builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
-        /*genDedup=*/true, needsUniv);
+        builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
   });
   assert(loop);
   return loop;
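The /*genDedup=*/true flag disappears from enterCoIterationOverTensorsAtLvls, which suggests deduplication of repeated coordinates is no longer a per-call option at this site. As a reminder of what dedup means for a level that may store duplicates (e.g. COO), a small standalone sketch; the helper name and signature are assumptions, not the MLIR API:

// Illustrative only: skip past duplicates of the coordinate at `pos`,
// returning the first position holding a different coordinate.
#include <cassert>
#include <cstdint>
#include <vector>

int64_t dedupSegmentEnd(const std::vector<int64_t> &coords, int64_t pos,
                        int64_t end) {
  assert(pos < end && "caller must hand in a valid position");
  int64_t c = coords[pos];
  while (pos < end && coords[pos] == c)
    ++pos;
  return pos; // end of the duplicate segment
}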
@@ -894,7 +893,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
     if (isCompressedLT(lt) || isSingletonLT(lt) ||
         isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
       assert(lvl.has_value());
-      const Value crd = env.emitter().getCoords()[tid][*lvl];
+      const Value crd = env.emitter().getCoord(tid, *lvl);
       const Value lvar = env.getLoopVar(curr);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              crd, lvar);
@@ -1032,10 +1031,14 @@ static bool getAllTidLvlsInLatPoints(
   });

   if (isDenseLT(env.lt(outTid, curr))) {
-    // Note that we generate dense indices of the output tensor
-    // unconditionally, since they may not appear in the lattice, but may be
-    // needed for linearized env.
-    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
+    auto stt = getSparseTensorType(env.op().getOutputs().front());
+    // Note that we generate dense indices of the output tensor
+    // unconditionally, since they may not appear in the lattice, but may be
+    // needed for linearized env.
+    // TODO: we should avoid introducing corner cases for all-dense sparse
+    // tensors.
+    if (stt.hasEncoding() && stt.isAllDense())
+      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }

   if (numloopCond == 0) {
@@ -1064,6 +1067,11 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,

   SmallVector<TensorLevel> tidLvls;
   getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    // TODO: remove this! The same tensor level might be added multiple
+    // times due to the special handling for the all-dense "sparse" output
+    // tensor (see L1038).
+    if (llvm::find(tidLvls, tl) != tidLvls.end())
+      return;
     tidLvls.emplace_back(tl);
   });
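The guard keeps tidLvls duplicate-free while preserving insertion order. Note that llvm::is_contained from llvm/ADT/STLExtras.h expresses the same membership test as the llvm::find comparison; a standalone framing of the dedup-append pattern, shown only as a readability alternative (the helper name appendUnique is hypothetical):

// Same semantics as `llvm::find(vec, val) != vec.end()` followed by append.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

template <typename T>
void appendUnique(llvm::SmallVectorImpl<T> &vec, const T &val) {
  if (llvm::is_contained(vec, val)) // skip values already present
    return;
  vec.push_back(val);
}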

@@ -1096,7 +1104,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   for (Level l = startLvl; l < lvlRank; l++) {
     AffineExpr lvlExpr = lvlExprs[l];
     if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-      env.emitter().genDenseAffineAddress(
+      env.emitter().locateLvlAtAffineAddress(
           builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
     else
       return; // break on first non-dense non-constant level
@@ -1145,7 +1153,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
   Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
   Location loc = env.op().getLoc();
   for (auto [tidLvl, exp] : affineTidLvls) {
-    env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
+    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
   }

   // Until now, we have entered every <tid, lvl> pair in {cond, extra,
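The rename genDenseAffineAddress -> locateLvlAtAffineAddress matches the iterator vocabulary: for a dense level, "locating" a constant affine coordinate amounts to linearized addressing under the parent position. A minimal sketch of that computation, assuming row-major dense storage (the helper name is hypothetical):

// Illustrative only: position of coordinate `crd` in a dense level of size
// `lvlSize`, nested under `parentPos`.
#include <cstdint>

int64_t locateDense(int64_t parentPos, int64_t lvlSize, int64_t crd) {
  return parentPos * lvlSize + crd; // row-major dense addressing
}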
