
[mlir][sparse] code cleanup, remove FIXMEs #73575


Merged · 4 commits · Nov 27, 2023
16 changes: 9 additions & 7 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
// Reordering.
//

-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);

} // namespace sparse_tensor
} // namespace mlir
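
For intuition, the two renamed helpers are mutual inverses on permuted encodings: with a column-major map `dimToLvl = (d0, d1) -> (d1, d0)`, level 0 stores dimension 1, so `toDim(enc, 0) == 1` and `toLvl(enc, 1) == 0`. A minimal sketch of that round-trip property, using only the API shown in this diff (the helper name is illustrative, not part of the PR):

// Illustrative sketch: on a permuted encoding, toLvl inverts toDim.
void checkDimLvlRoundTrip(SparseTensorEncodingAttr enc) {
  assert(enc.isPermutation() && "helpers require a permuted dim2lvl map");
  for (Level l = 0, e = enc.getLvlRank(); l < e; ++l)
    assert(toLvl(enc, toDim(enc, l)) == l && "round trip must be identity");
}
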
37 changes: 11 additions & 26 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
@@ -398,10 +396,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,

if (isPermutation()) {
for (unsigned r = 0; r < rank; r++) {
-      // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
ret.push_back(srcShape[trans]);
}
return ret;
@@ -925,31 +921,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
ordered);
}

-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }

-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }
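
Because `dimToLvl` is asserted to be a permutation, its inverse `lvlToDim` is available, and the rewritten `toLvl` replaces the old linear `getResultPosition` search with a direct `getDimPosition` lookup on that inverse. A hedged sketch of the equivalence the rewrite relies on, reusing only calls visible in this diff (the helper name is hypothetical):

// Hypothetical consistency check: the direct lookup on the inverse map
// must agree with the old search of dimToLvl for the expression `d`.
static void checkInverseLookup(SparseTensorEncodingAttr enc, Dimension d) {
  AffineMap dimToLvl = enc.getDimToLvl();
  AffineMap lvlToDim = enc.getLvlToDim();
  auto maybePos =
      dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
  assert(maybePos.has_value() && *maybePos == lvlToDim.getDimPosition(d) &&
         "inverse lookup must match the old linear search");
}
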
26 changes: 0 additions & 26 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
}
}

-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
7 changes: 0 additions & 7 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
size_t offsetIdx = 0, Value offsetVal = Value());

-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
// Generates code to cast a tensor to a memref.
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
// The individual mask bits.
kIncludeDenseOutput = 0x1, // b001
kIncludeDenseInput = 0x2, // b010
-  kIncludeUndef = 0x4, // b100
// The subsets of mask bits.
kIncludeAll = 0x7, // b111
kIncludeDense = 0x3, // b011
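
Since `SortMask` is a plain bit set, inclusion is a bitwise test; for example, `kIncludeDense` (b011) covers both dense flags but not the removed undef bit. A minimal sketch of such a test (the helper name is illustrative, not from this file):

// Illustrative mask test: true when `mask` has every bit of `flag` set.
static bool includesFlag(SortMask mask, SortMask flag) {
  return (static_cast<unsigned>(mask) & static_cast<unsigned>(flag)) ==
         static_cast<unsigned>(flag);
}
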
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

/// Converts a coordinate relative to the slice to the coordinate relative
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
// computation. Must be ordered from more strict to less strict.
// Ideally (though might not be guaranteed), the earlier a constraint mask
// can be satisfied, the faster the generated kernel will be.
-    const auto allMasks = {
-        SortMask::kIncludeAll, SortMask::kIncludeDense,
-        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
-        SortMask::kIncludeUndef, SortMask::kSparseOnly};
+    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
+                           SortMask::kIncludeDenseInput,
+                           SortMask::kIncludeDenseOutput,
+                           SortMask::kSparseOnly};
for (const SortMask mask : allMasks) {
order = scheduler.sort(mask);
if (order) {
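
The loop above tries each mask in the declared order and keeps the first schedule that succeeds, which is why `allMasks` must run from the most to the least constrained mask. A standalone sketch of that trial pattern under assumed names (`Scheduler`, `Order`, and `trySchedule` are illustrative, not from the PR):

// Illustrative trial loop: return the first feasible loop order, trying
// the most constrained sort masks first.
static std::optional<Order> trySchedule(Scheduler &scheduler,
                                        ArrayRef<SortMask> masks) {
  for (SortMask mask : masks)
    if (auto order = scheduler.sort(mask))
      return order;
  return std::nullopt;
}
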
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Level lvl = toStoredDim(encSrc, d);
+      Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}

@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
SmallVector<Value> srcDcvs;
srcDcvs.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Level lvl = toStoredDim(encSrc, d);
+      Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
return failure();

if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
return success();
}
