Skip to content

Commit dda3dc5

Browse files
authored
[mlir][sparse] simplify ConvertOp rewriting rules (#68350)
Canonicalize complex convertOp into multiple stages, such that it can either be done by a direct conversion or by sorting.
1 parent 12b87f6 commit dda3dc5

File tree

7 files changed

+277
-255
lines changed

7 files changed

+277
-255
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,17 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
195195
```
196196

197197
}];
198+
199+
let extraClassDeclaration = [{
200+
// Whether the convert can be done by a single step (either a sort or a foreach),
201+
// or it would require a tmp buffer (sort, then foreach).
202+
bool directConvertable();
203+
204+
// Whether the convert is actually a sort coo
205+
// TODO: The method will be removed when sort_coo operation is introduced.
206+
bool isSortCOOConvert();
207+
}];
208+
198209
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
199210
let hasFolder = 1;
200211
let hasVerifier = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,44 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
10661066
return {};
10671067
}
10681068

1069+
bool ConvertOp::directConvertable() {
1070+
if (isSortCOOConvert())
1071+
return false;
1072+
1073+
SparseTensorType srcStt = getSparseTensorType(getSource());
1074+
SparseTensorType dstStt = getSparseTensorType(getDest());
1075+
1076+
// We can always directly convert to unordered sparse tensor or dense tensor
1077+
// since dense tensor support random access.
1078+
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1079+
return true;
1080+
1081+
if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1082+
srcStt.hasSameDimToLvl(dstStt)) {
1083+
return true;
1084+
}
1085+
1086+
// Source and dest tensors are ordered in different ways. We only do direct
1087+
// dense to sparse conversion when the dense input is defined by a sparse
1088+
// constant. Note that we can theoretically always directly convert from dense
1089+
// inputs by rotating dense loops but it leads to bad cache locality and hurt
1090+
// performance.
1091+
if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1092+
if (isa<SparseElementsAttr>(constOp.getValue()))
1093+
return true;
1094+
1095+
return false;
1096+
}
1097+
1098+
bool ConvertOp::isSortCOOConvert() {
1099+
// TODO: we should instead use a different sort_coo operation to handle
1100+
// the conversion between COOs (but with different ordering).
1101+
return isUniqueCOOType(getSource().getType()) &&
1102+
isUniqueCOOType(getDest().getType()) &&
1103+
!getSparseTensorType(getSource()).isAllOrdered() &&
1104+
getSparseTensorType(getDest()).isAllOrdered();
1105+
}
1106+
10691107
LogicalResult ToPositionsOp::verify() {
10701108
auto e = getSparseTensorEncoding(getTensor().getType());
10711109
if (failed(lvlIsInBounds(getLevel(), getTensor())))

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,50 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
679679
}
680680
};
681681

682+
// TODO: use a new SortCOO operation here instead of reusing convert op.
683+
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
684+
using OpConversionPattern::OpConversionPattern;
685+
LogicalResult
686+
matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
687+
ConversionPatternRewriter &rewriter) const override {
688+
// Direct conversion should have already been lowered.
689+
if (!op.isSortCOOConvert())
690+
return failure();
691+
692+
Location loc = op.getLoc();
693+
MLIRContext *ctx = op.getContext();
694+
695+
SparseTensorType srcStt = getSparseTensorType(op.getSource());
696+
SparseTensorType dstStt = getSparseTensorType(op.getDest());
697+
698+
// TODO: This should be verification rules for sort_coo operation.
699+
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
700+
isUniqueCOOType(srcStt.getRankedTensorType()) &&
701+
isUniqueCOOType(dstStt.getRankedTensorType()));
702+
703+
assert(dstStt.hasSameDimToLvl(srcStt));
704+
705+
// We don't need a mutable descriptor here as we perform sorting in-place.
706+
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
707+
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
708+
auto crd = desc.getAOSMemRef();
709+
auto val = desc.getValMemRef();
710+
711+
// Otherwise we need another data shuffle and a non-identity map.
712+
assert(dstStt.hasSameDimToLvl(srcStt));
713+
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
714+
715+
rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
716+
rewriter.getIndexAttr(0),
717+
SparseTensorSortKind::HybridQuickSort);
718+
719+
// Since we do in-place sorting, the destinate tensor will have the same set
720+
// of memrefs as the source tensor.
721+
rewriter.replaceOp(op, adaptor.getSource());
722+
return success();
723+
}
724+
};
725+
682726
template <typename Op, StorageSpecifierKind kind>
683727
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
684728
public:
@@ -1101,6 +1145,9 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
11011145
LogicalResult
11021146
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
11031147
ConversionPatternRewriter &rewriter) const override {
1148+
if (op.isSortCOOConvert())
1149+
return failure();
1150+
11041151
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
11051152
SparseTensorEncodingAttr encSrc =
11061153
getSparseTensorEncoding(op.getSource().getType());
@@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
15541601
SparseCastConverter, SparseExtractSliceConverter,
15551602
SparseTensorLoadConverter, SparseExpandConverter,
15561603
SparseCompressConverter, SparseInsertConverter,
1604+
SparseSortCOOConverter,
15571605
SparseSliceGetterOpConverter<ToSliceOffsetOp,
15581606
StorageSpecifierKind::DimOffset>,
15591607
SparseSliceGetterOpConverter<ToSliceStrideOp,

0 commit comments

Comments
 (0)