[mlir][sparse] add merger support on Batch LevelType. #83186

Merged: 1 commit, Feb 27, 2024
18 changes: 15 additions & 3 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -333,16 +333,28 @@ struct LevelType {
     return lvlBits & static_cast<uint64_t>(p);
   }
 
+  /// Check if the `LevelType` is considered to be sparse.
+  constexpr bool hasSparseSemantic() const {
+    return isa<LevelFormat::Compressed, LevelFormat::Singleton,
+               LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
+  }
+
+  /// Check if the `LevelType` is considered to be dense-like.
+  constexpr bool hasDenseSemantic() const {
+    return isa<LevelFormat::Dense, LevelFormat::Batch>();
+  }
+
   /// Check if the `LevelType` needs positions array.
   constexpr bool isWithPosLT() const {
-    return isa<LevelFormat::Compressed>() ||
-           isa<LevelFormat::LooseCompressed>();
+    assert(!isa<LevelFormat::Undef>());
+    return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
   }
 
   /// Check if the `LevelType` needs coordinates array.
   constexpr bool isWithCrdLT() const {
+    assert(!isa<LevelFormat::Undef>());
     // All sparse levels have a coordinate array.
-    return !isa<LevelFormat::Dense, LevelFormat::Batch>();
+    return hasSparseSemantic();
   }
 
   std::string toMLIRString() const {
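The two new predicates split every defined level format into exactly one of two groups, and the storage queries follow from the group. A minimal sketch of that contract (illustrative only; `classify` is a hypothetical helper, not part of the patch):

    // Sketch only: every non-undef LevelType is either dense-like or sparse,
    // so exactly one of the two predicates holds.
    void classify(LevelType lt) {
      if (lt.hasDenseSemantic()) {
        // Dense and Batch levels locate elements directly, so they need
        // neither a positions array nor a coordinates array.
        assert(!lt.isWithPosLT() && !lt.isWithCrdLT());
      } else if (lt.hasSparseSemantic()) {
        // Compressed, Singleton, LooseCompressed, and NOutOfM levels all
        // carry a coordinates array; only the compressed variants also
        // carry a positions array.
        assert(lt.isWithCrdLT());
      }
    }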
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -509,8 +509,7 @@ class Merger {
   bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
-      return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
+      return lt.hasSparseSemantic();
     }
     return false;
   }
8 changes: 3 additions & 5 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
   // Starts resetting from a dense level, so that the first bit (if kept)
   // is not undefined level-type.
   for (unsigned b = 0; b < be; b++) {
-    if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
+    if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
       offset = be - b - 1; // relative to the end
       break;
     }

@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Slice on dense level has `locate` property as well, and can be optimized.
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto lt = getLvlType(b);
-      if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
-          !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
+      if (!lt.hasSparseSemantic()) {
         if (reset)
           simple.reset(b);
         reset = true;

@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto lt = getLvlType(b);
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        isNOutOfMLT(lt))
+    if (lt.hasSparseSemantic())
       return true;
   }
   return hasSparseIdxReduction(bits);
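All three rewrites above are intended as pure refactorings. The invariant they rely on can be stated directly (a sketch, assuming a `LevelType lt` with a defined, non-undef format; not code from the patch):

    // The new predicate must answer exactly what the old four-way test asked.
    bool oldWay = isCompressedLT(lt) || isSingletonLT(lt) ||
                  isLooseCompressedLT(lt) || isNOutOfMLT(lt);
    bool newWay = lt.hasSparseSemantic();
    assert(oldWay == newWay); // the rewrite is behavior-preserving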
58 changes: 40 additions & 18 deletions mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
 #undef IMPL_BINOP_PATTERN
 
-class MergerTestBase : public ::testing::Test {
+// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
+class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
 protected:
   MergerTestBase(unsigned numTensors, unsigned numLoops)
       : merger(numTensors, numLoops, /*maxRank=*/numLoops) {

@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
     // Tensor 1: sparse input vector.
     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
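For readers unfamiliar with GoogleTest value parameterization, here is a self-contained sketch of the `TestWithParam`/`TEST_P`/`INSTANTIATE_TEST_SUITE_P` pattern the patch uses (illustrative only; `Fmt`, `isDenseLike`, and `FormatTest` are made-up stand-ins, not names from the patch):

    #include <gtest/gtest.h>

    // Stand-ins for this sketch; the real tests parameterize over
    // mlir::sparse_tensor::LevelFormat and build merger fixtures instead.
    enum class Fmt { Dense, Batch };
    static bool isDenseLike(Fmt f) { return f == Fmt::Dense || f == Fmt::Batch; }

    // A fixture derived from TestWithParam<T>, like MergerTestBase above.
    class FormatTest : public ::testing::TestWithParam<Fmt> {};

    // Each TEST_P body runs once per instantiated value; GetParam() returns
    // the value for the current run.
    TEST_P(FormatTest, treatedAsDense) { EXPECT_TRUE(isDenseLike(GetParam())); }

    // One suite instantiation per value list, mirroring the diff above.
    INSTANTIATE_TEST_SUITE_P(AllFormats, FormatTest,
                             ::testing::Values(Fmt::Dense, Fmt::Batch));

This is why the existing tests switch from TEST_F to TEST_P below: only TEST_P tests can read GetParam() and be multiplied across the instantiated values.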

 /// Four tensors (three inputs, one output); and a single loop.
 class MergerTest4T1L : public MergerTestBase {
 protected:

@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
     // Tensor 2: sparse input vector
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
     // Tensor 3: dense output vector
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both sparse and dense input.
 ///

@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
     // Tensor 0: sparse input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both undef and dense input.
 ///

@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
     // Tensor 0: undef input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: undef input vector.
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
     // Tensor 3: dense output vector.
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with operation on sparse output.
 ///

@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
   }
 };
 
+// This test suite does not involve any dense-like level format, so
+// instantiating with just one of {Dense, Batch} is enough.
+INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
+                         ::testing::Values(LevelFormat::Dense));
+
 } // namespace
 
 /// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
 /// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
-  TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
+  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
     const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
     const auto e = CONJ2##Expr(em, tensor(2)); \
     const auto l0 = lid(0); \

@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
 /// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
-  TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
+  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
     const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
     const auto e = CONJ2##Expr(em, tensor(2)); \
     const auto l0 = lid(0); \

@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
 /// lat( i_01 / tensor_1 )
 /// }
 #define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
-  TEST_F(MergerTest3T1L, vector_##OP) { \
+  TEST_P(MergerTest3T1L, vector_##OP) { \
     const auto e = OP##Expr(tensor(0), tensor(1)); \
     const auto l0 = lid(0); \
     const auto t0 = tid(0); \

@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
 /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
 /// }
 #define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
-  TEST_F(MergerTest3T1L, vector_##OP) { \
+  TEST_P(MergerTest3T1L, vector_##OP) { \
     const auto e = OP##Expr(tensor(0), tensor(1)); \
     const auto l0 = lid(0); \
     const auto t0 = tid(0); \

@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 /// lat( i_02 / tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
-  TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
+  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
     const auto em = CONJ##Expr(tensor(0), tensor(1)); \
     const auto e = DISJ##Expr(em, tensor(2)); \
     const auto l0 = lid(0); \

@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
 /// lat( i_00 / tensor_0 )
 /// }
 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
-  TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
+  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
     const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
     const auto e = DISJ2##Expr(em, tensor(2)); \
     const auto l0 = lid(0); \

@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
 /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
-  TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
+  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
     const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
     const auto e = CONJ2##Expr(em, tensor(2)); \
     const auto l0 = lid(0); \

@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
     const auto e = OP##Expr(tensor(0), tensor(1)); \
     const auto l0 = lid(0); \
     const auto t0 = tid(0); \

@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
 /// }
 /// since i_01 is a dense dimension.
 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
     const auto e = OP##Expr(tensor(0), tensor(1)); \
     const auto l0 = lid(0); \
     const auto t0 = tid(0); \

@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
 /// lat( i_00 / tensor_0 cmp 0 )
 /// lat( i_01 / 0 cmp tensor_1 )
 /// }
-TEST_F(MergerTest3T1L, vector_cmp) {
+TEST_P(MergerTest3T1L, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);

@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
 ///
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
-TEST_F(MergerTest3T1LD, vector_cmp) {
+TEST_P(MergerTest3T1LD, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);