Skip to content

[mlir][sparse] simplify ConvertOp rewriting rules #68350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,17 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
```

}];

let extraClassDeclaration = [{
// Returns true when the conversion can be done in a single step (either a
// sort or a foreach); returns false when it requires a temporary buffer
// (i.e., a sort followed by a foreach).
bool directConvertable();

// Returns true when the conversion is really a sort on COO data, i.e.,
// an unordered unique-COO source converted to an ordered unique-COO dest.
// TODO: The method will be removed when sort_coo operation is introduced.
bool isSortCOOConvert();
}];

let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,44 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}

bool ConvertOp::directConvertable() {
  // A COO-to-COO sort is handled by a dedicated pattern, never directly.
  if (isSortCOOConvert())
    return false;

  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We can always directly convert to an unordered sparse tensor or a dense
  // tensor, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return true;

  // Past this point the destination is sparse and all-ordered, so the
  // conversion is direct exactly when the source is already ordered the same
  // way (no reordering of elements is needed).
  if (srcStt.isAllOrdered() && srcStt.hasSameDimToLvl(dstStt))
    return true;

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return true;

  return false;
}

bool ConvertOp::isSortCOOConvert() {
// TODO: we should instead use a different sort_coo operation to handle
// the conversion between COOs (but with different ordering).
return isUniqueCOOType(getSource().getType()) &&
isUniqueCOOType(getDest().getType()) &&
!getSparseTensorType(getSource()).isAllOrdered() &&
getSparseTensorType(getDest()).isAllOrdered();
}

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,50 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
}
};

// TODO: use a new SortCOO operation here instead of reusing convert op.
// Lowers a ConvertOp that is really a COO sort (unordered unique-COO to
// ordered unique-COO) into an in-place SortOp on the source's storage.
// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Direct conversions should have already been lowered by other patterns.
    if (!op.isSortCOOConvert())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    // TODO: This should be verification rules for sort_coo operation.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           isUniqueCOOType(srcStt.getRankedTensorType()) &&
           isUniqueCOOType(dstStt.getRankedTensorType()));
    // Only an identity dim-to-lvl mapping between source and dest is handled;
    // otherwise another data shuffle and a non-identity map would be needed.
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto nnz = genValMemSize(rewriter, loc, adaptor.getSource());
    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
                            rewriter.getIndexAttr(0),
                            SparseTensorSortKind::HybridQuickSort);

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
Expand Down Expand Up @@ -1101,6 +1145,9 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.isSortCOOConvert())
return failure();

SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
Expand Down Expand Up @@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseSortCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
Expand Down
Loading