
[mlir][tosa] Canonicalise slice over overlapped or inside a pad. #138270


Open. Wants to merge 1 commit into main.
Conversation

GeorgeARM
Contributor

Update the padding and/or the slice parameters when a `tosa.slice` that follows a
`tosa.pad` accesses a region that only partially overlaps the padding, or does not
touch the padding at all (i.e., the slice lies entirely within the original tensor).

Signed-off-by: Georgios Pinitas <[email protected]>
@llvmbot
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Georgios Pinitas (GeorgeARM)

Changes

Update the padding and/or the slice parameters when a `tosa.slice` that follows a `tosa.pad` accesses a region that only partially overlaps the padding, or does not touch the padding at all (i.e., the slice lies entirely within the original tensor).
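To make the rewrite concrete, here is a minimal rank-1 sketch with hypothetical values (the in-tree tests added below exercise rank-4 tensors). An 8-element input is padded by 2 on each side, and a slice of 10 elements starting at offset 1 reads one padded element on each end; the pattern therefore shrinks the padding to 1 per side and rebases the slice start to 0:

// Before: pad by [2, 2] to 12 elements, then slice [start = 1, size = 10].
%padding = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
%padded = tosa.pad %arg0, %padding, %pad_const : (tensor<8xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<12xf32>
%start = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%size = tosa.const_shape {values = dense<10> : tensor<1xindex>} : () -> !tosa.shape<1>
%sliced = tosa.slice %padded, %start, %size : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<10xf32>

// After: pad by [1, 1] only; the slice becomes a full-tensor copy
// (start = 0, size = 10) that existing folds can then remove.
%new_padding = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
%new_padded = tosa.pad %arg0, %new_padding, %pad_const : (tensor<8xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<10xf32>
%new_start = tosa.const_shape {values = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
%new_sliced = tosa.slice %new_padded, %new_start, %size : (tensor<10xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<10xf32>

Per dimension, the arithmetic in the pattern is newPadLo = max(padLo - sliceStart, 0), newPadHi = max(sliceEnd - (padLo + dimSize), 0), and newSliceStart = max(sliceStart - padLo, 0). When the slice lands entirely inside the original tensor, both new paddings are 0 and the zero-padding `tosa.pad` folds away, as the second test case checks.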


Full diff: https://github.com/llvm/llvm-project/pull/138270.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+123-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+36)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..5347fb1c16698 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Update the paddings
+    int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimStart = padLo;
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Compute update slice/pad parameters
+      if (sliceStart < dimStart || sliceEnd > dimEnd) {
+        // Handle slice when not within the original input entirely
+        updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                   (newSliceStart != sliceStart);
+        newPadPaddings[i * 2] = newPadLo;
+        newPadPaddings[i * 2 + 1] = newPadHi;
+        newSliceStarts[i] = newSliceStart;
+      } else {
+        // Slice is within the original input
+        updated |= newSliceStart != sliceStart;
+        newSliceStarts[i] = newSliceStart;
+      }
+
+      // Calculate new pad output shape
+      newPadShape[i] =
+          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+    }
+
+    // Check that we actually need to proceed with the rewrite
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
 // Update size operand of tosa.slice if size has dynamic dims but corresponding
 // output dim is static
 struct SliceDynamicSizeCanonicalization
@@ -779,8 +900,8 @@ struct SliceDynamicSizeCanonicalization
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..6e99f57341982 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,42 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
 
 // -----
 
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape  {values = dense<[1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x18x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape  {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape  {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape  {values = dense<[1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x18x3xf32>
+  return %sliced : tensor<1x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape  {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape  {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape  {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape  {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape  {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
+  return %sliced : tensor<1x14x14x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_log_exp
 func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg{{.*}} : tensor<?x1xf32>

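The new test cases follow the existing FileCheck conventions in canonicalize.mlir; assuming the file's usual RUN line, they can be exercised locally with `mlir-opt --split-input-file --canonicalize mlir/test/Dialect/Tosa/canonicalize.mlir | FileCheck mlir/test/Dialect/Tosa/canonicalize.mlir`, or via the `check-mlir` build target.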