diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..5347fb1c16698 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Update the paddings
+    int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimStart = padLo;
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Compute updated slice/pad parameters
+      if (sliceStart < dimStart || sliceEnd > dimEnd) {
+        // Handle slice when not within the original input entirely
+        updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                   (newSliceStart != sliceStart);
+        newPadPaddings[i * 2] = newPadLo;
+        newPadPaddings[i * 2 + 1] = newPadHi;
+        newSliceStarts[i] = newSliceStart;
+      } else {
+        // Slice is within the original input
+        updated |= newSliceStart != sliceStart;
+        newSliceStarts[i] = newSliceStart;
+      }
+
+      // Calculate new pad output shape
+      newPadShape[i] =
+          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+    }
+
+    // Check that we actually need to proceed with the rewrite
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
 // Update size operand of tosa.slice if size has dynamic dims but corresponding
 // output dim is static
 struct SliceDynamicSizeCanonicalization
@@ -779,8 +900,8 @@ struct SliceDynamicSizeCanonicalization
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
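To make the rewrite concrete, here is a minimal hand-written sketch (not part of the patch) of what `PadSliceOptimization` does when the slice window overlaps the padded border. It assumes a rank-1 input of 8 elements padded by [2, 2] and a 10-element slice starting at 1; all SSA names are illustrative:

```mlir
// Before: pad 8 -> 12, then slice the window [1, 11) out of the padded tensor.
%pads = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
%padded = tosa.pad %in, %pads, %pc : (tensor<8xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<12xf32>
%start = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%size = tosa.const_shape {values = dense<10> : tensor<1xindex>} : () -> !tosa.shape<1>
%out = tosa.slice %padded, %start, %size : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<10xf32>

// After: newPadLo = max(2 - 1, 0) = 1, newPadHi = max(11 - (2 + 8), 0) = 1,
// newSliceStart = max(1 - 2, 0) = 0, so the pad shrinks to [1, 1] and the
// slice reads the whole 10-element pad result starting at 0.
%new_pads = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
%new_padded = tosa.pad %in, %new_pads, %pc : (tensor<8xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<10xf32>
%zero = tosa.const_shape {values = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
%out2 = tosa.slice %new_padded, %zero, %size : (tensor<10xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<10xf32>
```

The 4-D `@canonicalize_pad_slice_overlap` test in the hunk below applies exactly this arithmetic on dimension 2: pads [2, 2] with slice start 1 become pads [1, 1] with slice start 0.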
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..6e99f57341982 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,42 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
 
 // -----
 
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x18x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x18x3xf32>
+  return %sliced : tensor<1x14x18x3xf32>
+}
+
+// -----
+
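The next test exercises the other branch of the pattern: the slice window lies entirely inside the original input on every dimension, so every entry of `newPadPaddings` stays zero and only the start shifts down by `padLo` per dimension (start [0, 1, 4, 0] minus pads [0, 0, 2, 0] gives [0, 1, 2, 0]). A sketch of the intermediate IR right after the pattern fires, with illustrative SSA names; the zero pad is then removed by `tosa.pad`'s existing fold of all-zero padding, which is why the test can use `CHECK-NOT: tosa.pad`:

```mlir
%zero_pads = tosa.const_shape {values = dense<0> : tensor<8xindex>} : () -> !tosa.shape<8>
%nop_pad = tosa.pad %arg0, %zero_pads, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x16x3xf32>
%new_start = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%sliced = tosa.slice %nop_pad, %new_start, %size : (tensor<1x16x16x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x10x3xf32>
```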
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x10x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x10x3xf32>
+  return %sliced : tensor<1x14x10x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_log_exp
 func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg{{.*}} : tensor<?x1xf32>
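As with the other cases in this file, the new tests are driven by the RUN line at the top of canonicalize.mlir (not shown in this hunk); reproducing locally should amount to piping the file through `mlir-opt --canonicalize` and `FileCheck`, at which point the pad/slice pairs appear rewritten exactly as the CHECK lines above expect.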