Skip to content

[mlir][TilingInterface] Make the tiling set tile sizes function use OpFoldResult. #66566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace mlir {
namespace scf {

using SCFTileSizeComputationFunction =
std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;

/// Options to use to control tiling.
struct SCFTilingOptions {
Expand All @@ -40,17 +40,10 @@ struct SCFTilingOptions {
tileSizeComputationFunction = std::move(fun);
return *this;
}
/// Set the `tileSizeComputationFunction` to return the values `ts`. The
/// values must not fold away when tiling. Otherwise, use a more robust
/// `tileSizeComputationFunction`.
SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
return *this;
}
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);

/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
Expand Down
24 changes: 12 additions & 12 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,

scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
tilingOptions = tilingOptions.setTileSizes(tileSizes);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
LogicalResult result = applyTilingToAll(
Expand Down Expand Up @@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
auto diag = mlir::emitSilenceableFailure(getLoc())
<< "could not find next producer to fuse into container";
<< "could not find next producer to fuse into container";
diag.attachNote(containingOp->getLoc()) << "containing op";
return diag;
}
Expand Down Expand Up @@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
SmallVector<Value, 4> tileSizes;
SmallVector<OpFoldResult> tileSizes;
Location loc = target.getLoc();
SmallVector<OpFoldResult> allShapeSizes =
target.createFlatListOfOperandDims(b, loc);
Expand All @@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// If the shape size is dynamic, tile by 1.
// Otherwise, do not tile (i.e. tile size 0).
for (OpFoldResult shapeSize : shapeSizes) {
tileSizes.push_back(getConstantIntValue(shapeSize)
? b.create<arith::ConstantIndexOp>(loc, 0)
: b.create<arith::ConstantIndexOp>(loc, 1));
tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
: b.getIndexAttr(1));
}
return tileSizes;
});
Expand Down Expand Up @@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
Operation *) {
SmallVector<Value, 4> sizes;
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;

Expand All @@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
getLoc(), attr.cast<IntegerAttr>().getInt());
Value vscale =
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
sizes.push_back(
b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
} else {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
sizes.push_back(attr);
}
continue;
}
Expand All @@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
assert((dynamicSizes.empty() ^ params.empty()) &&
"expected either dynamic sizes or parameters");
if (!params.empty()) {
sizes.push_back(
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
sizes.push_back(b.getIndexAttr(params[index]));
} else {
sizes.push_back(dynamicSizes[index]->getResult(0));
}
Expand Down
50 changes: 21 additions & 29 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,11 @@
using namespace mlir;

scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
assert(!tileSizeComputationFunction && "tile sizes already set");
SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
auto tileSizes = llvm::to_vector(ts);
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(
&op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
->getRegion(0)
.front());
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
return v;
}));
return tileSizes;
};
return *this;
}
Expand Down Expand Up @@ -108,17 +100,16 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
/// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
/// the tile processed within the innermost loop.
static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder &builder, Location loc,
ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes) {
static SmallVector<scf::ForOp> generateTileLoopNest(
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes) {
assert(!loopRanges.empty() && "expected at least one loop range");
assert(loopRanges.size() == tileSizeVals.size() &&
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
OpBuilder::InsertionGuard guard(builder);
SmallVector<scf::ForOp> loops;
Expand All @@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
Value size =
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
Value tileSize = tileSizeVals[loopRange.index()];
Value tileSize = getValueOrCreateConstantIndexOp(
builder, loc, tileSizes[loopRange.index()]);
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (matchPattern(tileSize, m_Zero())) {
Expand Down Expand Up @@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
SmallVector<Value> tileSizeVector =
SmallVector<OpFoldResult> tileSizeVector =
options.tileSizeComputationFunction(rewriter, op);
if (tileSizeVector.size() < iterationDomain.size()) {
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto zero = rewriter.getIndexAttr(0);
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}

Expand Down Expand Up @@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSize) {
ArrayRef<OpFoldResult> tileSizes) {
Location loc = op.getLoc();
// Ops implementing PartialReductionOpInterface are expected to implement
// TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
SmallVector<Value> tileSizeVector =
getValueOrCreateConstantIndexOp(b, loc, tileSize);
if (tileSizeVector.size() < iterationDomain.size()) {
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
auto tileSizesVector = llvm::to_vector(tileSizes);
if (tileSizesVector.size() < iterationDomain.size()) {
auto zero = b.getIndexAttr(0);
tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
zero);
}
if (op->getNumResults() != 1)
return b.notifyMatchFailure(
Expand All @@ -429,15 +421,15 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,

// 1. Create the initial tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, tileSize,
op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
// 2. Create the nested loops.
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> loops = generateTileLoopNest(
b, loc, iterationDomain, tileSizeVector, offsets, sizes);
b, loc, iterationDomain, tileSizesVector, offsets, sizes);

// 3. Generate the tiled implementation within the inner most loop.
b.setInsertionPoint(loops.back().getBody()->getTerminator());
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/transform-op-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
// -----

// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[VS:.*]] = vector.vscale
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C128:.*]] = arith.constant 128 : index
// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
Expand Down
14 changes: 10 additions & 4 deletions mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
Expand All @@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
Expand All @@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
interchange);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
.setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
Expand Down