From 63c3a8ed60f06bfe30edfb6b8fb272dcf34778e9 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Fri, 15 Sep 2023 18:24:23 -0700 Subject: [PATCH] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`. --- .../SCF/Transforms/TileUsingInterface.h | 11 +--- .../TransformOps/LinalgTransformOps.cpp | 24 ++++----- .../SCF/Transforms/TileUsingInterface.cpp | 50 ++++++++----------- .../Dialect/Linalg/transform-op-tile.mlir | 4 +- .../TilingInterface/TestTilingInterface.cpp | 14 ++++-- 5 files changed, 47 insertions(+), 56 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index e7bcd062d9652..ca641c596c7b7 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -26,7 +26,7 @@ namespace mlir { namespace scf { using SCFTileSizeComputationFunction = - std::function(OpBuilder &, Operation *)>; + std::function(OpBuilder &, Operation *)>; /// Options to use to control tiling. struct SCFTilingOptions { @@ -40,17 +40,10 @@ struct SCFTilingOptions { tileSizeComputationFunction = std::move(fun); return *this; } - /// Set the `tileSizeComputationFunction` to return the values `ts`. The - /// values must not fold away when tiling. Otherwise, use a more robust - /// `tileSizeComputationFunction`. - SCFTilingOptions &setTileSizes(const SmallVector &ts) { - tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; - return *this; - } /// Convenience function to set the `tileSizeComputationFunction` to a /// function that computes tile sizes at the point they are needed. Allows /// proper interaction with folding. - SCFTilingOptions &setTileSizes(ArrayRef ts); + SCFTilingOptions &setTileSizes(ArrayRef ts); /// The interchange vector to reorder the tiled loops. SmallVector interchangeVector = {}; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index de4270ab38004..1819ca614a060 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; - tilingOptions = tilingOptions.setTileSizes(tileSizes); + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); + tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( @@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, auto nextProducer = getNextProducer(); if (failed(nextProducer)) { auto diag = mlir::emitSilenceableFailure(getLoc()) - << "could not find next producer to fuse into container"; + << "could not find next producer to fuse into container"; diag.attachNote(containingOp->getLoc()) << "containing op"; return diag; } @@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { - SmallVector tileSizes; + SmallVector tileSizes; Location loc = target.getLoc(); SmallVector allShapeSizes = target.createFlatListOfOperandDims(b, loc); @@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, // If the shape size is dynamic, tile by 1. // Otherwise, do not tile (i.e. tile size 0). for (OpFoldResult shapeSize : shapeSizes) { - tileSizes.push_back(getConstantIntValue(shapeSize) - ? b.create(loc, 0) - : b.create(loc, 1)); + tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0) + : b.getIndexAttr(1)); } return tileSizes; }); @@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter, if (!tileSizes.empty()) { tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b, Operation *) { - SmallVector sizes; + SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; @@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter, getLoc(), attr.cast().getInt()); Value vscale = b.create(getLoc(), b.getIndexType()); - sizes.push_back(b.create(getLoc(), val, vscale)); + sizes.push_back( + b.create(getLoc(), val, vscale).getResult()); } else { - sizes.push_back(b.create( - getLoc(), cast(attr).getInt())); + sizes.push_back(attr); } continue; } @@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter, assert((dynamicSizes.empty() ^ params.empty()) && "expected either dynamic sizes or parameters"); if (!params.empty()) { - sizes.push_back( - b.create(getLoc(), params[index])); + sizes.push_back(b.getIndexAttr(params[index])); } else { sizes.push_back(dynamicSizes[index]->getResult(0)); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 1ce25565edcaf..c782583c32eb6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -31,19 +31,11 @@ using namespace mlir; scf::SCFTilingOptions & -scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { +scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { assert(!tileSizeComputationFunction && "tile sizes already set"); - SmallVector tileSizes(ts.begin(), ts.end()); + auto tileSizes = llvm::to_vector(ts); tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentWithTrait() - ->getRegion(0) - .front()); - return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { - Value v = b.create(op->getLoc(), s); - return v; - })); + return tileSizes; }; return *this; } @@ -108,17 +100,16 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, /// Generate an empty loop nest that represents the tiled loop nest shell. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. +/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. /// - In `offsets` and `sizes` return the multi-dimensional offset and size of /// the /// tile processed within the inner most loop. -static SmallVector -generateTileLoopNest(OpBuilder &builder, Location loc, - ArrayRef loopRanges, ArrayRef tileSizeVals, - SmallVector &offsets, - SmallVector &sizes) { +static SmallVector generateTileLoopNest( + OpBuilder &builder, Location loc, ArrayRef loopRanges, + ArrayRef tileSizes, SmallVector &offsets, + SmallVector &sizes) { assert(!loopRanges.empty() && "expected at least one loop range"); - assert(loopRanges.size() == tileSizeVals.size() && + assert(loopRanges.size() == tileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(builder); SmallVector loops; @@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc, getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); Value size = getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); - Value tileSize = tileSizeVals[loopRange.index()]; + Value tileSize = getValueOrCreateConstantIndexOp( + builder, loc, tileSizes[loopRange.index()]); // No loops if tile size is zero. Set offset and size to the loop // offset and size. if (matchPattern(tileSize, m_Zero())) { @@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, // skips tiling a particular dimension. This convention is significantly // simpler to handle instead of adjusting affine maps to account for missing // dimensions. - SmallVector tileSizeVector = + SmallVector tileSizeVector = options.tileSizeComputationFunction(rewriter, op); if (tileSizeVector.size() < iterationDomain.size()) { - auto zero = rewriter.create(op.getLoc(), 0); + auto zero = rewriter.getIndexAttr(0); tileSizeVector.append(numLoops - tileSizeVector.size(), zero); } @@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, FailureOr mlir::scf::tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, - ArrayRef tileSize) { + ArrayRef tileSizes) { Location loc = op.getLoc(); // Ops implementing PartialReductionOpInterface are expected to implement // TilingInterface. auto tilingInterfaceOp = cast(op.getOperation()); SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); - SmallVector tileSizeVector = - getValueOrCreateConstantIndexOp(b, loc, tileSize); - if (tileSizeVector.size() < iterationDomain.size()) { - auto zero = b.create(loc, 0); - tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); + auto tileSizesVector = llvm::to_vector(tileSizes); + if (tileSizesVector.size() < iterationDomain.size()) { + auto zero = b.getIndexAttr(0); + tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), + zero); } if (op->getNumResults() != 1) return b.notifyMatchFailure( @@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 1. create the inital tensor value. FailureOr identityTensor = - op.generateInitialTensorForPartialReduction(b, loc, tileSize, + op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, reductionDims); if (failed(identityTensor)) return b.notifyMatchFailure(op, @@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 2. Create the nested loops. SmallVector offsets, sizes; SmallVector loops = generateTileLoopNest( - b, loc, iterationDomain, tileSizeVector, offsets, sizes); + b, loc, iterationDomain, tileSizesVector, offsets, sizes); // 3. Generate the tiled implementation within the inner most loop. b.setInsertionPoint(loops.back().getBody()->getTerminator()); diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir index ce2a3d6ca9c58..9df19632506a7 100644 --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -190,16 +190,16 @@ transform.sequence failures(propagate) { // ----- // CHECK-LABEL: func.func @scalable_and_fixed_length_tile -// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index -// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[VS:.*]] = vector.vscale // CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index // CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]] // CHECK: %[[C0_1:.*]] = arith.constant 0 : index // CHECK: %[[C128_1:.*]] = arith.constant 128 : index +// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index // CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]] // CHECK: %[[C0_2:.*]] = arith.constant 0 : index // CHECK: %[[C128_2:.*]] = arith.constant 128 : index diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 752c885e0b87b..2fcc7bcadb604 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context, ArrayRef tileSizes, ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(context, tileSizes); + tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); LinalgTransformationFilter filter(StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add(context, tilingOptions, filter); @@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context, ArrayRef tileSizes, ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(context, tileSizes); + tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); LinalgTransformationFilter filter(StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add( @@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context, ArrayRef tileSizes, ArrayRef interchange = {}) { scf::SCFTileAndFuseOptions tileAndFuseOptions; - tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( - interchange); + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(context, tileSizes); + tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr) + .setInterchange(interchange); LinalgTransformationFilter filter(StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add(