@@ -679,6 +679,50 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
   }
 };
 
+// TODO: use a new SortCOO operation here instead of reusing convert op.
+struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Direct conversion should have already been lowered.
+    if (!op.isSortCOOConvert())
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+
+    SparseTensorType srcStt = getSparseTensorType(op.getSource());
+    SparseTensorType dstStt = getSparseTensorType(op.getDest());
+
+    // TODO: This should be verification rules for sort_coo operation.
+    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
+           isUniqueCOOType(srcStt.getRankedTensorType()) &&
+           isUniqueCOOType(dstStt.getRankedTensorType()));
+
+    assert(dstStt.hasSameDimToLvl(srcStt));
+
+    // We don't need a mutable descriptor here as we perform sorting in-place.
+    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto crd = desc.getAOSMemRef();
+    auto val = desc.getValMemRef();
+
+    // Otherwise we need another data shuffle and a non-identity map.
+    assert(dstStt.hasSameDimToLvl(srcStt));
+    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
+
+    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
+                            rewriter.getIndexAttr(0),
+                            SparseTensorSortKind::HybridQuickSort);
+
+    // Since we do in-place sorting, the destination tensor will have the same
+    // set of memrefs as the source tensor.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 template <typename Op, StorageSpecifierKind kind>
 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
 public:
@@ -1101,6 +1145,9 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.isSortCOOConvert())
+      return failure();
+
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
@@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseCastConverter, SparseExtractSliceConverter,
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
+               SparseSortCOOConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
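
For reference, the runtime effect of the SortOp emitted above is a joint sort: the AoS coordinate buffer of the unordered COO tensor is put into lexicographic order and the value buffer is permuted along with it, which is why the pattern can simply reuse the source memrefs via adaptor.getSource(). Below is a minimal standalone C++ sketch of that effect. The helper name, element types, and the scratch permutation buffer are illustrative assumptions only; the actual SortOp lowering sorts the coordinate memref in place with a hybrid quicksort rather than going through this code.

  // Sketch only: reorders the nnz = val.size() elements of a COO tensor whose
  // level coordinates are stored AoS (rank consecutive entries per element)
  // into lexicographic coordinate order, permuting values along with their
  // coordinates.
  #include <algorithm>
  #include <cstdint>
  #include <numeric>
  #include <vector>

  void sortCOO(std::vector<uint64_t> &crd, std::vector<double> &val,
               uint64_t rank) {
    const uint64_t nnz = val.size();
    // Order a permutation of [0, nnz) by lexicographic coordinate comparison.
    std::vector<uint64_t> perm(nnz);
    std::iota(perm.begin(), perm.end(), uint64_t{0});
    std::sort(perm.begin(), perm.end(), [&](uint64_t a, uint64_t b) {
      return std::lexicographical_compare(
          crd.begin() + a * rank, crd.begin() + (a + 1) * rank,
          crd.begin() + b * rank, crd.begin() + (b + 1) * rank);
    });
    // Apply the permutation to both the coordinate and the value buffers.
    std::vector<uint64_t> newCrd(crd.size());
    std::vector<double> newVal(nnz);
    for (uint64_t i = 0; i < nnz; ++i) {
      std::copy(crd.begin() + perm[i] * rank,
                crd.begin() + (perm[i] + 1) * rank, newCrd.begin() + i * rank);
      newVal[i] = val[perm[i]];
    }
    crd.swap(newCrd);
    val.swap(newVal);
  }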