diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 5e53dbe1cc283..029ecb0708941 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, // p = (lo+hi)/2 // pivot index // i = lo // j = hi-1 -// while (i < j) do { +// while (true) do { // while (xs[i] < xs[p]) i ++; // i_eq = (xs[i] == xs[p]); // while (xs[j] > xs[p]) j --; // j_eq = (xs[j] == xs[p]); +// +// if (i >= j) return j + 1; +// // if (i < j) { // swap(xs[i], xs[j]) // if (i == p) { @@ -581,8 +584,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module, // } // } // } -// return p -// } +// } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo, uint32_t nTrailingP = 0) { @@ -605,22 +607,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, Value i = lo; Value j = builder.create(loc, hi, c1); createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args); - SmallVector operands{i, j, p}; // Exactly three values. - SmallVector types{i.getType(), j.getType(), p.getType()}; + Value trueVal = constantI1(builder, loc, true); // The value for while (true) + SmallVector operands{i, j, p, trueVal}; // Exactly four values. + SmallVector types{i.getType(), j.getType(), p.getType(), + trueVal.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); // The before-region of the WhileOp. - Block *before = - builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc}); + Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, + {loc, loc, loc, loc}); builder.setInsertionPointToEnd(before); - Value cond = builder.create(loc, arith::CmpIPredicate::ult, - before->getArgument(0), - before->getArgument(1)); - builder.create(loc, cond, before->getArguments()); + builder.create(loc, before->getArgument(3), + before->getArguments()); // The after-region of the WhileOp. Block *after = - builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc}); + builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); builder.setInsertionPointToEnd(after); i = after->getArgument(0); j = after->getArgument(1); @@ -637,7 +639,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, j = jresult; // If i < j: - cond = builder.create(loc, arith::CmpIPredicate::ult, i, j); + Value cond = + builder.create(loc, arith::CmpIPredicate::ult, i, j); scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector swapOperands{i, j}; @@ -675,11 +678,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, builder.setInsertionPointAfter(ifOp2); builder.create( loc, - ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)}); + ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), + /*cont=*/constantI1(builder, loc, true)}); - // False branch for if i < j: + // False branch for if i < j (i.e., i >= j): builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, ValueRange{i, j, p}); + p = builder.create(loc, j, + constantOne(builder, loc, j.getType())); + builder.create( + loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); // Return for the whileOp. builder.setInsertionPointAfter(ifOp); @@ -927,6 +934,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, Location loc = func.getLoc(); Value lo = args[loIdx]; Value hi = args[hiIdx]; + SmallVector types(2, lo.getType()); // Only two types. + FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx, ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc); @@ -935,14 +944,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, TypeRange{IndexType::get(context)}, args.drop_back(nTrailingP)) .getResult(0); - Value pP1 = - builder.create(loc, p, constantIndex(builder, loc, 1)); + Value lenLow = builder.create(loc, p, lo); Value lenHigh = builder.create(loc, hi, p); + // Partition already sorts array with len <= 2 + Value c2 = constantIndex(builder, loc, 2); + Value len = builder.create(loc, hi, lo); + Value lenGtTwo = + builder.create(loc, arith::CmpIPredicate::ugt, len, c2); + scf::IfOp ifLenGtTwo = + builder.create(loc, types, lenGtTwo, /*else=*/true); + builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); + // Returns an empty range to mark the entire region is fully sorted. + builder.create(loc, ValueRange{lo, lo}); + + // Else len > 2, need recursion. + builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); Value cond = builder.create(loc, arith::CmpIPredicate::ule, lenLow, lenHigh); - SmallVector types(2, lo.getType()); // Only two types. scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); Value c0 = constantIndex(builder, loc, 0); @@ -961,14 +981,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, // the bigger partition to be processed by the enclosed while-loop. builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); mayRecursion(lo, p, lenLow); - builder.create(loc, ValueRange{pP1, hi}); + builder.create(loc, ValueRange{p, hi}); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - mayRecursion(pP1, hi, lenHigh); + mayRecursion(p, hi, lenHigh); builder.create(loc, ValueRange{lo, p}); builder.setInsertionPointAfter(ifOp); - return std::make_pair(ifOp.getResult(0), ifOp.getResult(1)); + builder.create(loc, ifOp.getResults()); + + builder.setInsertionPointAfter(ifLenGtTwo); + return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); } /// Creates a function to perform insertion sort on the values in the range of diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir index 170f851138f82..0036bd5c3310b 100644 --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -75,343 +75,9 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f // ----- -// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index( -// CHECK-SAME: %[[VAL_0:.*0]]: index, -// CHECK-SAME: %[[VAL_1:.*1]]: index, -// CHECK-SAME: %[[VAL_2:.*2]]: memref, -// CHECK-SAME: %[[VAL_3:.*3]]: memref, -// CHECK-SAME: %[[VAL_4:.*4]]: memref) -> index { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1000 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -1 -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] -// CHECK: %[[VAL_9:.*]] = arith.shrui %[[VAL_8]], %[[VAL_5]] -// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_1]], %[[VAL_5]] -// CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] -// CHECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_6]] -// CHECK: scf.if %[[VAL_12]] { -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_14]] -// CHECK: scf.if %[[VAL_15]] { -// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_19]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_18]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_21]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_20]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_23]] -// CHECK: scf.if %[[VAL_24]] { -// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_26]], %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_28]], %[[VAL_3]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_27]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_30]], %[[VAL_4]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_29]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_31]], %[[VAL_32]] -// CHECK: scf.if %[[VAL_33]] { -// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_35]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_34]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_37]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_36]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_39]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_38]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: } -// CHECK: } else { -// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] -// CHECK: %[[VAL_41:.*]] = arith.shrui %[[VAL_40]], %[[VAL_5]] -// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_44:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_43]] -// CHECK: scf.if %[[VAL_44]] { -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_46]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_45]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_48]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_47]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_50]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_49]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_51]] -// CHECK: scf.if %[[VAL_52]] { -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_53]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_53]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_54]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_54]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_55]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_55]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_57]] -// CHECK: scf.if %[[VAL_58]] { -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_60]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_59]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_62]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_61]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_64]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_63]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: } -// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_66]] -// CHECK: scf.if %[[VAL_67]] { -// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_69]], %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_68]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_71]], %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_70]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_73]], %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_72]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_74]], %[[VAL_74]] -// CHECK: scf.if %[[VAL_75]] { -// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_76]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_76]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_77]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_77]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_78]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_78]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_81:.*]] = arith.cmpi ult, %[[VAL_79]], %[[VAL_80]] -// CHECK: scf.if %[[VAL_81]] { -// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_83:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_83]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_82]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_84:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_85]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_84]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_87]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_86]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_90:.*]] = arith.cmpi ult, %[[VAL_88]], %[[VAL_89]] -// CHECK: scf.if %[[VAL_90]] { -// CHECK: %[[VAL_91:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_92:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_92]], %[[VAL_2]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_91]], %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_93:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_94]], %[[VAL_3]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_93]], %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_10]]] -// CHECK: %[[VAL_96:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_96]], %[[VAL_4]]{{\[}}%[[VAL_10]]] -// CHECK: memref.store %[[VAL_95]], %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_97:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_98:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_99:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_98]] -// CHECK: scf.if %[[VAL_99]] { -// CHECK: %[[VAL_100:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_101:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_101]], %[[VAL_2]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_100]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_102:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_103:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_103]], %[[VAL_3]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_102]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_104:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: %[[VAL_105:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_105]], %[[VAL_4]]{{\[}}%[[VAL_41]]] -// CHECK: memref.store %[[VAL_104]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_106:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_107:.*]] = arith.cmpi ult, %[[VAL_106]], %[[VAL_106]] -// CHECK: scf.if %[[VAL_107]] { -// CHECK: %[[VAL_108:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_108]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_108]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_109:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_109]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_109]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_110:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_110]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_110]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_111:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_112:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_113:.*]] = arith.cmpi ult, %[[VAL_111]], %[[VAL_112]] -// CHECK: scf.if %[[VAL_113]] { -// CHECK: %[[VAL_114:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_115:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_115]], %[[VAL_2]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_114]], %[[VAL_2]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_116:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_117:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_117]], %[[VAL_3]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_116]], %[[VAL_3]]{{\[}}%[[VAL_0]]] -// CHECK: %[[VAL_118:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: %[[VAL_119:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: memref.store %[[VAL_119]], %[[VAL_4]]{{\[}}%[[VAL_9]]] -// CHECK: memref.store %[[VAL_118]], %[[VAL_4]]{{\[}}%[[VAL_0]]] -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: %[[VAL_120:.*]]:3 = scf.while (%[[VAL_121:.*]] = %[[VAL_0]], %[[VAL_122:.*]] = %[[VAL_10]], %[[VAL_123:.*]] = %[[VAL_9]]) -// CHECK: %[[VAL_124:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_122]] -// CHECK: scf.condition(%[[VAL_124]]) %[[VAL_121]], %[[VAL_122]], %[[VAL_123]] -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_125:.*]]: index, %[[VAL_126:.*]]: index, %[[VAL_127:.*]]: index) -// CHECK: %[[VAL_128:.*]] = scf.while (%[[VAL_129:.*]] = %[[VAL_125]]) -// CHECK: %[[VAL_130:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_129]]] -// CHECK: %[[VAL_131:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]] -// CHECK: %[[VAL_132:.*]] = arith.cmpi ult, %[[VAL_130]], %[[VAL_131]] -// CHECK: scf.condition(%[[VAL_132]]) %[[VAL_129]] -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_133:.*]]: index): -// CHECK: %[[VAL_134:.*]] = arith.addi %[[VAL_133]], %[[VAL_5]] -// CHECK: scf.yield %[[VAL_134]] -// CHECK: } -// CHECK: %[[VAL_135:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136:.*]]] -// CHECK: %[[VAL_137:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]] -// CHECK: %[[VAL_138:.*]] = arith.cmpi eq, %[[VAL_135]], %[[VAL_137]] -// CHECK: %[[VAL_139:.*]] = scf.while (%[[VAL_140:.*]] = %[[VAL_126]]) -// CHECK: %[[VAL_141:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]] -// CHECK: %[[VAL_142:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_140]]] -// CHECK: %[[VAL_143:.*]] = arith.cmpi ult, %[[VAL_141]], %[[VAL_142]] -// CHECK: scf.condition(%[[VAL_143]]) %[[VAL_140]] -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_144:.*]]: index): -// CHECK: %[[VAL_145:.*]] = arith.addi %[[VAL_144]], %[[VAL_7]] -// CHECK: scf.yield %[[VAL_145]] -// CHECK: } -// CHECK: %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]] -// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]] -// CHECK: %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]] -// CHECK: %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]] -// CHECK: %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]] -// CHECK: %[[VAL_153:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147]]] -// CHECK: memref.store %[[VAL_153]], %[[VAL_2]]{{\[}}%[[VAL_136]]] -// CHECK: memref.store %[[VAL_152]], %[[VAL_2]]{{\[}}%[[VAL_147]]] -// CHECK: %[[VAL_154:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_136]]] -// CHECK: %[[VAL_155:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_147]]] -// CHECK: memref.store %[[VAL_155]], %[[VAL_3]]{{\[}}%[[VAL_136]]] -// CHECK: memref.store %[[VAL_154]], %[[VAL_3]]{{\[}}%[[VAL_147]]] -// CHECK: %[[VAL_156:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_136]]] -// CHECK: %[[VAL_157:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_147]]] -// CHECK: memref.store %[[VAL_157]], %[[VAL_4]]{{\[}}%[[VAL_136]]] -// CHECK: memref.store %[[VAL_156]], %[[VAL_4]]{{\[}}%[[VAL_147]]] -// CHECK: %[[VAL_158:.*]] = arith.cmpi eq, %[[VAL_136]], %[[VAL_127]] -// CHECK: %[[VAL_159:.*]] = scf.if %[[VAL_158]] -// CHECK: scf.yield %[[VAL_147]] -// CHECK: } else { -// CHECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_147]], %[[VAL_127]] -// CHECK: %[[VAL_161:.*]] = arith.select %[[VAL_160]], %[[VAL_136]], %[[VAL_127]] -// CHECK: scf.yield %[[VAL_161]] -// CHECK: } -// CHECK: %[[VAL_162:.*]] = arith.andi %[[VAL_138]], %[[VAL_149]] : i1 -// CHECK: %[[VAL_163:.*]]:2 = scf.if %[[VAL_162]] -// CHECK: %[[VAL_164:.*]] = arith.addi %[[VAL_136]], %[[VAL_5]] -// CHECK: %[[VAL_165:.*]] = arith.subi %[[VAL_147]], %[[VAL_5]] -// CHECK: scf.yield %[[VAL_164]], %[[VAL_165]] -// CHECK: } else { -// CHECK: scf.yield %[[VAL_136]], %[[VAL_147]] -// CHECK: } -// CHECK: scf.yield %[[VAL_166:.*]]#0, %[[VAL_166]]#1, %[[VAL_167:.*]] -// CHECK: } else { -// CHECK: scf.yield %[[VAL_136]], %[[VAL_147]], %[[VAL_127]] -// CHECK: } -// CHECK: scf.yield %[[VAL_168:.*]]#0, %[[VAL_168]]#1, %[[VAL_168]]#2 -// CHECK: } -// CHECK: return %[[VAL_169:.*]]#2 -// CHECK: } - -// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index( -// CHECK-SAME: %[[L:arg0]]: index, -// CHECK-SAME: %[[H:.*]]: index, -// CHECK-SAME: %[[X0:.*]]: memref, -// CHECK-SAME: %[[Y0:.*]]: memref, -// CHECK-SAME: %[[Y1:.*]]: memref) { -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK: scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]]) -// CHECK: %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]] -// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]] -// CHECK: scf.condition(%[[COND]]) %[[L2]], %[[H2]] -// CHECK: } do { -// CHECK: ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index) -// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]]) -// CHECK: %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index -// CHECK: %[[LenL:.*]] = arith.subi %[[P]], %[[L3]] -// CHECK: %[[LenH:.*]] = arith.subi %[[H3]], %[[P]] -// CHECK: %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]] -// CHECK: %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]] -// CHECK: %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]] -// CHECK: scf.if %[[Cmp]] -// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]]) -// CHECK: else -// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]]) -// CHECK: scf.yield %[[L4]], %[[H4]] -// CHECK: } -// CHECK: return -// CHECK: } - -// CHECK-LABEL: func.func @sparse_sort_1d2v_quick( -// CHECK-SAME: %[[N:.*]]: index, -// CHECK-SAME: %[[X0:.*]]: memref<10xi8>, -// CHECK-SAME: %[[Y0:.*]]: memref, -// CHECK-SAME: %[[Y1:.*]]: memref<10xindex>) -> (memref<10xi8>, memref, memref<10xindex>) { -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref -// CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref -// CHECK: call @_sparse_qsort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]]) -// CHECK: return %[[X0]], %[[Y0]], %[[Y1]] -// CHECK: } +// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index +// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index +// CHECK-LABEL: func.func @sparse_sort_1d2v_quick func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref, %arg3: memref<10xindex>) -> (memref<10xi8>, memref, memref<10xindex>) { sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref, memref<10xindex> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir index c3bdc30e355b1..ca5dd00d02aff 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir @@ -94,14 +94,14 @@ module { %y1 = memref.cast %y1s : memref<7xi32> to memref // Sort "parallel arrays". - // CHECK: ( 1, 1, 2, 5, 10 ) - // CHECK: ( 3, 3, 1, 10, 1 ) - // CHECK: ( 9, 9, 4, 7, 2 ) - // CHECK: ( 7, 8, 10, 9, 6 ) - // CHECK: ( 7, 4, 7, 9, 5 ) - call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) + // CHECK: ( 1, 1, 3, 3, 10 ) + // CHECK: ( 2, 10, 1, 1, 5 ) + // CHECK: ( 4, 2, 9, 9, 7 ) + // CHECK: ( 10, 6, 7, 8, 9 ) + // CHECK: ( 7, 5, 7, 4, 9 ) + call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3) : (memref>, i32, i32, i32, i32, i32) -> () - call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3) + call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1) : (memref>, i32, i32, i32, i32, i32) -> () call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9) : (memref>, i32, i32, i32, i32, i32) -> () @@ -122,14 +122,14 @@ module { %y1v = vector.transfer_read %y1[%i0], %c100: memref, vector<5xi32> vector.print %y1v : vector<5xi32> // Stable sort. - // CHECK: ( 1, 1, 2, 5, 10 ) - // CHECK: ( 3, 3, 1, 10, 1 ) - // CHECK: ( 9, 9, 4, 7, 2 ) - // CHECK: ( 8, 7, 10, 9, 6 ) - // CHECK: ( 4, 7, 7, 9, 5 ) - call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) + // CHECK: ( 1, 1, 3, 3, 10 ) + // CHECK: ( 2, 10, 1, 1, 5 ) + // CHECK: ( 4, 2, 9, 9, 7 ) + // CHECK: ( 10, 6, 8, 7, 9 ) + // CHECK: ( 7, 5, 4, 7, 9 ) + call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3) : (memref>, i32, i32, i32, i32, i32) -> () - call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3) + call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1) : (memref>, i32, i32, i32, i32, i32) -> () call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9) : (memref>, i32, i32, i32, i32, i32) -> () @@ -150,14 +150,14 @@ module { %y1v2 = vector.transfer_read %y1[%i0], %c100: memref, vector<5xi32> vector.print %y1v2 : vector<5xi32> // Heap sort. - // CHECK: ( 1, 1, 2, 5, 10 ) - // CHECK: ( 3, 3, 1, 10, 1 ) - // CHECK: ( 9, 9, 4, 7, 2 ) - // CHECK: ( 7, 8, 10, 9, 6 ) - // CHECK: ( 7, 4, 7, 9, 5 ) - call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1) + // CHECK: ( 1, 1, 3, 3, 10 ) + // CHECK: ( 2, 10, 1, 1, 5 ) + // CHECK: ( 4, 2, 9, 9, 7 ) + // CHECK: ( 10, 6, 8, 7, 9 ) + // CHECK: ( 7, 5, 4, 7, 9 ) + call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3) : (memref>, i32, i32, i32, i32, i32) -> () - call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3) + call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1) : (memref>, i32, i32, i32, i32, i32) -> () call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9) : (memref>, i32, i32, i32, i32, i32) -> ()