[mlir][sparse] add merger support on Batch LevelType. #83186
Merged
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/83186.diff
4 Files Affected:
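Before reading the diff, a quick orientation: the heart of the change is a pair of new classification predicates on LevelType, hasSparseSemantic() and hasDenseSemantic(), which let the merger treat Batch exactly like Dense instead of enumerating the four sparse formats at every call site. The following is a minimal, self-contained sketch of that classification; the enum and free functions are illustrative stand-ins for the real packed-uint64_t LevelType in Enums.h, not the actual API.

#include <cassert>
#include <initializer_list>

// Illustrative stand-in for sparse_tensor::LevelFormat.
enum class LevelFormat {
  Dense, Batch, Compressed, Singleton, LooseCompressed, NOutOfM
};

// Sparse-semantic formats are the ones that need a coordinates array.
constexpr bool hasSparseSemantic(LevelFormat f) {
  return f == LevelFormat::Compressed || f == LevelFormat::Singleton ||
         f == LevelFormat::LooseCompressed || f == LevelFormat::NOutOfM;
}

// Dense-like formats: Batch is deliberately grouped with Dense.
constexpr bool hasDenseSemantic(LevelFormat f) {
  return f == LevelFormat::Dense || f == LevelFormat::Batch;
}

int main() {
  // Batch behaves like Dense for the merger's purposes.
  assert(hasDenseSemantic(LevelFormat::Batch));
  assert(!hasSparseSemantic(LevelFormat::Batch));
  // Every defined format falls in exactly one of the two classes.
  for (LevelFormat f : {LevelFormat::Dense, LevelFormat::Batch,
                        LevelFormat::Compressed, LevelFormat::Singleton,
                        LevelFormat::LooseCompressed, LevelFormat::NOutOfM})
    assert(hasSparseSemantic(f) != hasDenseSemantic(f));
  return 0;
}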
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 9e79b6aca1c9ba..5563cb907e9353 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -333,16 +333,28 @@ struct LevelType {
return lvlBits & static_cast<uint64_t>(p);
}
+ /// Check if the `LevelType` is considered to be sparse.
+ constexpr bool hasSparseSemantic() const {
+ return isa<LevelFormat::Compressed, LevelFormat::Singleton,
+ LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
+ }
+
+ /// Check if the `LevelType` is considered to be dense-like.
+ constexpr bool hasDenseSemantic() const {
+ return isa<LevelFormat::Dense, LevelFormat::Batch>();
+ }
+
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT() const {
- return isa<LevelFormat::Compressed>() ||
- isa<LevelFormat::LooseCompressed>();
+ assert(!isa<LevelFormat::Undef>());
+ return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
}
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT() const {
+ assert(!isa<LevelFormat::Undef>());
// All sparse levels have a coordinate array.
- return !isa<LevelFormat::Dense, LevelFormat::Batch>();
+ return hasSparseSemantic();
}
std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 490ef3071af1b7..7f9820df984b29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -509,8 +509,7 @@ class Merger {
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
- return isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || isNOutOfMLT(lt);
+ return lt.hasSparseSemantic();
}
return false;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 731cd79a1e3b4b..72b722c69ae34b 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Starts resetting from a dense level, so that the first bit (if kept)
// is not undefined level-type.
for (unsigned b = 0; b < be; b++) {
- if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
+ if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
offset = be - b - 1; // relative to the end
break;
}
@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Slice on dense level has `locate` property as well, and can be optimized.
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
- if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
- !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
+ if (!lt.hasSparseSemantic()) {
if (reset)
simple.reset(b);
reset = true;
@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- isNOutOfMLT(lt))
+ if (lt.hasSparseSemantic())
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 62a19c084cac0f..943e7d5c120b87 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
-class MergerTestBase : public ::testing::Test {
+// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
+class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
: merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
// Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
+::testing::Values(LevelFormat::Dense,
+LevelFormat::Batch));
+
/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase {
protected:
@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
// Tensor 2: sparse input vector
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector
- merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
+::testing::Values(LevelFormat::Dense,
+LevelFormat::Batch));
+
///
/// Tests with both sparse and dense input.
///
@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
+::testing::Values(LevelFormat::Dense,
+LevelFormat::Batch));
+
///
/// Tests with both undef and dense input.
///
@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
// Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: undef input vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector.
- merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
+::testing::Values(LevelFormat::Dense,
+LevelFormat::Batch));
+
///
/// Tests with operation on sparse output.
///
@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
}
};
+// This testsuite does not use any dense-like format, just one of {Dense, Batch}
+// is enough.
+INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
+::testing::Values(LevelFormat::Dense));
+
} // namespace
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
- TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
- TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
/// lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
- TEST_F(MergerTest3T1L, vector_##OP) { \
+ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
- TEST_F(MergerTest3T1L, vector_##OP) { \
+ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
/// lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
- TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
+ TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
const auto em = CONJ##Expr(tensor(0), tensor(1)); \
const auto e = DISJ##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
/// lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
- TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
+ TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
const auto e = DISJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
- TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
- TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
- TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
/// lat( i_00 / tensor_0 cmp 0 )
/// lat( i_01 / 0 cmp tensor_1 )
/// }
-TEST_F(MergerTest3T1L, vector_cmp) {
+TEST_P(MergerTest3T1L, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);
@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
-TEST_F(MergerTest3T1LD, vector_cmp) {
+TEST_P(MergerTest3T1LD, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);
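A note on the test mechanics, since the diff shows only deltas: the fixtures move from TEST_F to googletest's value-parameterized tests, so each existing merger test now runs once per level format supplied by INSTANTIATE_TEST_SUITE_P, and GetParam() stands in wherever LevelFormat::Dense was hard-coded. Below is a minimal sketch of the same pattern, with illustrative names not taken from the PR.

#include <gtest/gtest.h>

// Illustrative stand-in for the real enum.
enum class LevelFormat { Dense, Batch };

// TestWithParam<T> makes GetParam() available inside each TEST_P body.
class DenseLikeTest : public ::testing::TestWithParam<LevelFormat> {};

TEST_P(DenseLikeTest, TreatedAsDense) {
  // One test body, executed once per value listed in the instantiation.
  LevelFormat f = GetParam();
  EXPECT_TRUE(f == LevelFormat::Dense || f == LevelFormat::Batch);
}

// Registers two runs of every TEST_P above, one per listed value.
INSTANTIATE_TEST_SUITE_P(AllDenseLike, DenseLikeTest,
                         ::testing::Values(LevelFormat::Dense,
                                           LevelFormat::Batch));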
aartbik approved these changes Feb 27, 2024