Skip to content

[mlir][sparse] add merger support on Batch LevelType. #83186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Feb 27, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/83186.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+15-3)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+3-5)
  • (modified) mlir/unittests/Dialect/SparseTensor/MergerTest.cpp (+40-18)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 9e79b6aca1c9ba..5563cb907e9353 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -333,16 +333,28 @@ struct LevelType {
     return lvlBits & static_cast<uint64_t>(p);
   }
 
+  /// Check if the `LevelType` is considered to be sparse.
+  constexpr bool hasSparseSemantic() const {
+    return isa<LevelFormat::Compressed, LevelFormat::Singleton,
+               LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
+  }
+
+  /// Check if the `LevelType` is considered to be dense-like.
+  constexpr bool hasDenseSemantic() const {
+    return isa<LevelFormat::Dense, LevelFormat::Batch>();
+  }
+
   /// Check if the `LevelType` needs positions array.
   constexpr bool isWithPosLT() const {
-    return isa<LevelFormat::Compressed>() ||
-           isa<LevelFormat::LooseCompressed>();
+    assert(!isa<LevelFormat::Undef>());
+    return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
   }
 
   /// Check if the `LevelType` needs coordinates array.
   constexpr bool isWithCrdLT() const {
+    assert(!isa<LevelFormat::Undef>());
     // All sparse levels has coordinate array.
-    return !isa<LevelFormat::Dense, LevelFormat::Batch>();
+    return hasSparseSemantic();
   }
 
   std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 490ef3071af1b7..7f9820df984b29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -509,8 +509,7 @@ class Merger {
   bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
-      return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
+      return lt.hasSparseSemantic();
     }
     return false;
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 731cd79a1e3b4b..72b722c69ae34b 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Starts resetting from a dense level, so that the first bit (if kept)
     // is not undefined level-type.
     for (unsigned b = 0; b < be; b++) {
-      if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
+      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
         offset = be - b - 1; // relative to the end
         break;
       }
@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Slice on dense level has `locate` property as well, and can be optimized.
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto lt = getLvlType(b);
-      if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
-          !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
+      if (!lt.hasSparseSemantic()) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto lt = getLvlType(b);
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        isNOutOfMLT(lt))
+    if (lt.hasSparseSemantic())
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 62a19c084cac0f..943e7d5c120b87 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
 #undef IMPL_BINOP_PATTERN
 
-class MergerTestBase : public ::testing::Test {
+// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
+class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
 protected:
   MergerTestBase(unsigned numTensors, unsigned numLoops)
       : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
     // Tensor 1: sparse input vector.
     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 /// Four tensors (three inputs, one output); and a single loop.
 class MergerTest4T1L : public MergerTestBase {
 protected:
@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
     // Tensor 2: sparse input vector
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
     // Tensor 3: dense output vector
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both sparse and dense input.
 ///
@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
     // Tensor 0: sparse input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both undef and dense input.
 ///
@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
     // Tensor 0: undef input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: undef input vector.
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
     // Tensor 3: dense output vector.
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with operation on sparse output.
 ///
@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
   }
 };
 
+// This testsuite does not use any dense-like format, just one of {Dense, Batch}
+// is enough.
+INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
+                         ::testing::Values(LevelFormat::Dense));
+
 } // namespace
 
 /// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
 ///   lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
-  TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
+  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
 ///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
-  TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
+  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
 ///   lat( i_01 / tensor_1 )
 /// }
 #define IMPL_MERGER_TEST_DISJ(OP, UNUSED)                                      \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
+  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
 /// }
 #define IMPL_MERGER_TEST_CONJ(OP, UNUSED)                                      \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
+  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 ///    lat( i_02 / tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
-  TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
+  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
     const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
     const auto e = DISJ##Expr(em, tensor(2));                                  \
     const auto l0 = lid(0);                                                    \
@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
 ///   lat( i_00 / tensor_0 )
 /// }
 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
-  TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
+  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
     const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = DISJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
 ///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
-  TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
+  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED)                            \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
 /// }
 /// since i_01 is a dense dimension.
 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED)                            \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
 ///   lat( i_00 / tensor_0 cmp 0 )
 ///   lat( i_01 / 0 cmp tensor_1 )
 /// }
-TEST_F(MergerTest3T1L, vector_cmp) {
+TEST_P(MergerTest3T1L, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);
@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
 ///
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
-TEST_F(MergerTest3T1LD, vector_cmp) {
+TEST_P(MergerTest3T1LD, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);

@PeimingLiu PeimingLiu merged commit d82e93e into llvm:main Feb 27, 2024
@PeimingLiu PeimingLiu deleted the merger-batch branch February 27, 2024 22:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants