Skip to content

Revert "[mlir][sparse] implement lowering rules for IterateOp." #95826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024

Conversation

PeimingLiu
Copy link
Member

Reverts #95286

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Jun 17, 2024
@PeimingLiu PeimingLiu merged commit 996905d into main Jun 17, 2024
@PeimingLiu PeimingLiu deleted the revert-95286-lower-iter branch June 17, 2024 18:35
@llvmbot
Copy link
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Reverts llvm/llvm-project#95286


Full diff: https://github.com/llvm/llvm-project/pull/95826.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+1-120)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (-40)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+3-23)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir (+13-41)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f57be49f21b8c..62887c75c872b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,20 +34,6 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
-static std::optional<LogicalResult>
-convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
-  // The actually Iterator Values (that are updated every iteration).
-  auto idxTp = IndexType::get(itTp.getContext());
-  // TODO: handle batch dimension.
-  assert(itTp.getEncoding().getBatchLvlRank() == 0);
-  if (!itTp.isUnique()) {
-    // Segment high for non-unique iterator.
-    fields.push_back(idxTp);
-  }
-  fields.push_back(idxTp);
-  return success();
-}
-
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -71,114 +57,10 @@ class ExtractIterSpaceConverter
   }
 };
 
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
-public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
-  LogicalResult
-  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
-    if (!op.getCrdUsedLvls().empty())
-      return rewriter.notifyMatchFailure(
-          op, "non-empty coordinates list not implemented.");
-
-    Location loc = op.getLoc();
-
-    auto iterSpace = SparseIterationSpace::fromValues(
-        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
-
-    std::unique_ptr<SparseIterator> it =
-        iterSpace.extractIterator(rewriter, loc);
-
-    if (it->iteratableByFor()) {
-      auto [lo, hi] = it->genForCond(rewriter, loc);
-      Value step = constantIndex(rewriter, loc, 1);
-      SmallVector<Value> ivs;
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      forOp.getBody()->erase();
-      Region &dstRegion = forOp.getRegion();
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-      auto yieldOp =
-          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(forOp.getBody());
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
-      yieldOp.erase();
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-    } else {
-      SmallVector<Value> ivs;
-      llvm::append_range(ivs, it->getCursor());
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-
-      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-      TypeRange types = ValueRange(ivs).getTypes();
-      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
-      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
-      // Generates loop conditions.
-      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-      rewriter.setInsertionPointToStart(before);
-      ValueRange bArgs = before->getArguments();
-      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-      assert(remArgs.size() == adaptor.getInitArgs().size());
-      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-      // Generates loop body.
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      Region &dstRegion = whileOp.getAfter();
-      // TODO: handle uses of coordinate!
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-      ValueRange aArgs = whileOp.getAfterArguments();
-      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
-          whileOp.getAfterBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
-
-      aArgs = it->linkNewScope(aArgs);
-      ValueRange nx = it->forward(rewriter, loc);
-      SmallVector<Value> yields;
-      llvm::append_range(yields, nx);
-      llvm::append_range(yields, yieldOp.getResults());
-
-      // replace sparse_tensor.yield with scf.yield.
-      yieldOp->erase();
-      rewriter.create<scf::YieldOp>(loc, yields);
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(
-          op, whileOp.getResults().drop_front(it->getCursor().size()),
-          resultMapping);
-    }
-    return success();
-  }
-};
-
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
-  addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -192,6 +74,5 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
-      converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index ef95fcc84bd90..be8e15d6ae6f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,13 +331,6 @@ class TrivialIterator : public ConcreteIterator {
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
-  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
-                  Value posLo, Value posHi)
-      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
-        posHi(posHi) {
-    seek(posLo);
-  }
-
   std::string getDebugInterfacePrefix() const override {
     return std::string("trivial<") + stl.toString() + ">";
   }
@@ -427,14 +420,6 @@ class DedupIterator : public ConcreteIterator {
       : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
-
-  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
-                Value posLo, Value posHi)
-      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
-    assert(!stl.isUnique());
-    seek({posLo, genSegmentHigh(b, l, posLo)});
-  }
-
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kDedup;
@@ -1547,11 +1532,6 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
   return space;
 }
 
-std::unique_ptr<SparseIterator>
-SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
-  return makeSimpleIterator(b, l, *this);
-}
-
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
@@ -1610,26 +1590,6 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
   return std::make_pair(std::move(stl), std::move(it));
 }
 
-std::unique_ptr<SparseIterator>
-sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
-                                  const SparseIterationSpace &iterSpace) {
-  // assert(iterSpace.getSpaceDim() == 1);
-  std::unique_ptr<SparseIterator> ret;
-  if (!iterSpace.isUnique()) {
-    // We always dedupliate the non-unique level, but we should optimize it away
-    // if possible.
-    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
-                                          iterSpace.getBoundLo(),
-                                          iterSpace.getBoundHi());
-  } else {
-    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
-                                            iterSpace.getBoundLo(),
-                                            iterSpace.getBoundHi());
-  }
-  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
-  return ret;
-}
-
 std::unique_ptr<SparseIterator>
 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                   SparseEmitStrategy strategy) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 91f363db93f1d..17636af2b2f9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,10 +132,6 @@ class SparseIterationSpace {
   Value getBoundLo() const { return bound.first; }
   Value getBoundHi() const { return bound.second; }
 
-  // Extract an iterator to iterate over the sparse iteration space.
-  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
-                                                  Location l) const;
-
 private:
   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
   std::pair<Value, Value> bound;
@@ -196,13 +192,6 @@ class SparseIterator {
     crd = nullptr;
   }
 
-  // Reconstructs a iteration space directly from the provided ValueRange.
-  static std::unique_ptr<SparseIterator>
-  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
-
-  // The inverse operation of `fromValues`.
-  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
-
   //
   // Iterator properties.
   //
@@ -356,21 +345,12 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a TensorLevel object from given ValueRange.
+/// Helper function to create a TensorLevel object from given `tensor`.
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-
-/// Helper function to create a simple SparseIterator object that iterate
-/// over the entire iteration space.
-std::unique_ptr<SparseIterator>
-makeSimpleIterator(OpBuilder &b, Location l,
-                   const SparseIterationSpace &iterSpace);
-
-/// Helper function to create a simple SparseIterator object that iterate
-/// over the sparse tensor level.
-/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when
-/// feature complete.
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the SparseTensorLevel.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,
     SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 77a0e89dc7c81..5fcd661bb69b2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
-// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
 
 #COO = #sparse_tensor.encoding<{
   map = (i, j) -> (
@@ -8,44 +7,17 @@
   )
 }>
 
-// CHECK-LABEL:   @sparse_iteration_to_scf
-//                  // deduplication
-// CHECK:           scf.while {{.*}} {
-// CHECK:           } do {
-// CHECK:           }
-// CHECK:           scf.while {{.*}} {
-// CHECK:           } do {
-//                    // actual computation
-// CHECK:             scf.for {{.*}} {
-// CHECK:               arith.addi
-// CHECK:             }
-//                    // deduplication
-// CHECK:             scf.while {{.*}} {
-// CHECK:             } do {
-// CHECK:             }
-// CHECK:             scf.yield
-// CHECK:           }
-// CHECK:           return
-
-// COLLAPSED-LABEL:   @sparse_iteration_to_scf
-// COLLAPSED:           %[[RET:.*]] = scf.for {{.*}} {
-// COLLAPSED:             %[[VAL:.*]] = arith.addi
-// COLLAPSED:             scf.yield %[[VAL]] : index
-// COLLAPSED:           }
-// COLLAPSED:           return %[[RET]] : index
-func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
-  %i = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
-      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
-    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
-        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
-    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
-      %k = arith.addi %inner, %c1 : index
-      sparse_tensor.yield %k : index
-    }
-    sparse_tensor.yield %r2 : index
-  }
-  return %r1 : index
+// CHECK-LABEL:   func.func @sparse_1D_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
 }

PeimingLiu pushed a commit to PeimingLiu/llvm-project that referenced this pull request Jun 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants