
Commit 7181785

Authored Jun 23, 2025
[mlir][PartialReductionTilingInterface] Generalize implementation of tileUsingSCF for ReductionTilingStrategy::PartialOuterReduction. (#143467)
This is a precursor to generalizing `tileUsingSCF` to handle the `ReductionTilingStrategy::PartialOuterParallel` strategy. This change generalizes/refactors the current implementation, which supports only `ReductionTilingStrategy::PartialOuterReduction`.

Changes in this PR:
- Move the `ReductionTilingStrategy` enum out of `scf::SCFTilingOptions` and make it visible to `TilingInterface`.
- `PartialTilingInterface` changes:
  - Pass the `tilingStrategy` used for partial reduction to `tileToPartialReduction`.
  - Pass the reduction dimensions along as `const llvm::SetVector<unsigned> &`.
- Allow `scf::SCFTilingOptions` to set the reduction dimensions that are to be tiled.
- Change `structured.tile_reduction_using_for` to allow specification of the reduction dimensions to be partially tiled.

Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e80acd4 commit 7181785

File tree

9 files changed: +433 −251 lines changed
 

‎mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 1 deletion

@@ -1859,6 +1859,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
     - the result-combining op,
     - the parent `for` op.
 
+    The `reduction_dims` can be used to specify the subset of reduction dimensions
+    of the operation to tile. If left unspecified, all reduction dimensions are
+    tiled.
+
     #### Example:
 
     ```
@@ -1909,7 +1913,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_op,
                       TransformHandleTypeInterface:$combining_op,
@@ -1922,6 +1927,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
    `by` `tile_sizes` `=` $tile_sizes
     attr-dict
     `:` functional-type(operands, results)

‎mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 21 additions & 28 deletions

@@ -85,28 +85,21 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that dont support such a mapping (like
+  /// `scf.for`)
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related reduction tiling
+  //-------------------------------------------------------------------------//
+
   /// Specify how reduction dimensions should be tiled.
-  ///
-  /// Tiling can be thought of as splitting a dimension into 2 and materializing
-  /// the outer dimension as a loop:
-  ///
-  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
-  ///
-  /// For parallel dimensions, the split can only happen in one way, with both
-  /// dimensions being parallel. For reduction dimensions however, there is a
-  /// choice in how we split the reduction dimension. This enum exposes this
-  /// choice.
-  enum class ReductionTilingStrategy {
-    // [reduction] -> [reduction1, reduction2]
-    // -> loop[reduction1] { [reduction2] }
-    FullReduction,
-    // [reduction] -> [reduction1, parallel2]
-    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
-    PartialReductionOuterReduction,
-    // [reduction] -> [parallel1, reduction2]
-    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
-    PartialReductionOuterParallel
-  };
   ReductionTilingStrategy reductionStrategy =
       ReductionTilingStrategy::FullReduction;
   SCFTilingOptions &
@@ -115,13 +108,13 @@ struct SCFTilingOptions {
     return *this;
   }
 
-  /// Specify mapping of loops to devices. This is only respected when the loop
-  /// constructs support such a mapping (like `scf.forall`). Will be ignored
-  /// when using loop constructs that dont support such a mapping (like
-  /// `scf.for`)
-  SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
-    mappingVector = llvm::to_vector(mapping);
+  /// Specify the reduction dimensions to be tiled. Note that this needs to be
+  /// specified. If left unspecified, then none of the reduction dimensions are
+  /// tiled.
+  SetVector<unsigned> reductionDims;
+  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
+    reductionDims.clear();
+    reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
 };
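
These options are consumed by `tileUsingSCF`. As a usage illustration (not part of this commit), here is a minimal sketch of how a caller might configure them for partial reduction tiling; the helper name `tilePartialOuterReduction` is invented for the example, and it assumes a rewriter and an op implementing the tiling interfaces are already in hand, mirroring what `tileReductionUsingScf` and the updated transform op do.

```
// Minimal sketch (assumed caller-side code, not from this commit): configure
// scf::SCFTilingOptions for partial-outer-reduction tiling with explicit
// reduction dimensions, then call tileUsingSCF.
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

using namespace mlir;

static FailureOr<scf::SCFTilingResult>
tilePartialOuterReduction(RewriterBase &rewriter, TilingInterface op,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<unsigned> reductionDims) {
  scf::SCFTilingOptions options;
  // Materialize the tiled loops as scf.for.
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
  // Split each tiled reduction dimension into an outer reduction loop and an
  // inner parallel tile; partial results are merged after the loop nest.
  options.setReductionTilingStrategy(
      ReductionTilingStrategy::PartialReductionOuterReduction);
  options.setTileSizes(tileSizes);
  // With this change the reduction dimensions must be listed explicitly;
  // leaving reductionDims unset means no reduction dimension is tiled.
  options.setReductionDims(reductionDims);
  return scf::tileUsingSCF(rewriter, op, options);
}
```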

‎mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 21 additions & 0 deletions

@@ -36,6 +36,27 @@ struct TilingResult {
   SmallVector<Operation *> generatedSlices;
 };
 
+/// Tiling can be thought of as splitting a dimension into 2 and
+/// materializing the outer dimension as a loop:
+///
+/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
+///
+/// For parallel dimensions, the split can only happen in one way, with both
+/// dimensions being parallel. For reduction dimensions however, there is a
+/// choice in how we split the reduction dimension. This enum exposes this
+/// choice.
+enum class ReductionTilingStrategy {
+  // [reduction] -> [reduction1, reduction2]
+  // -> loop[reduction1] { [reduction2] }
+  FullReduction,
+  // [reduction] -> [reduction1, parallel2]
+  // -> loop[reduction1] { [parallel2] }; merge[reduction1]
+  PartialReductionOuterReduction,
+  // [reduction] -> [parallel1, reduction2]
+  // -> loop[parallel1] { [reduction2] }; merge[parallel1]
+  PartialReductionOuterParallel
+};
+
 /// Container for the result of merge operation of tiling.
 /// - `mergeOps` contains operations created during the merge.
 /// - `replacements` contains the values that represents the result of the
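
To make the splitting choice concrete, below is a small standalone illustration (plain C++, deliberately not the MLIR API) of what `FullReduction` and `PartialReductionOuterReduction` compute for a 1-D sum of size 8 with an inner tile of size 4; the data and sizes are made up for the example.

```
// Standalone illustration of the reduction splitting choice described by
// ReductionTilingStrategy, for a 1-D sum of 8 elements tiled by 4.
#include <array>
#include <cstdio>

int main() {
  std::array<float, 8> x = {1, 2, 3, 4, 5, 6, 7, 8};

  // FullReduction: the loop over tiles carries the final scalar accumulator.
  float full = 0.0f;
  for (int tile = 0; tile < 8; tile += 4)
    for (int i = 0; i < 4; ++i)
      full += x[tile + i];

  // PartialReductionOuterReduction: the loop over tiles carries a tile-sized
  // vector of partial sums (the reduction dimension becomes parallel inside
  // the loop); a separate merge step reduces the partials afterwards.
  std::array<float, 4> partial = {0, 0, 0, 0};
  for (int tile = 0; tile < 8; tile += 4)
    for (int i = 0; i < 4; ++i)
      partial[i] += x[tile + i]; // each element i is independent of the others
  float merged = 0.0f;
  for (float p : partial)
    merged += p; // merge step, e.g. linalg.reduce in the MLIR lowering

  std::printf("full = %f, merged = %f\n", full, merged); // both print 36
  return 0;
}
```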

‎mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 6 additions & 5 deletions

@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
         "::mlir::OpBuilder &":$b,
         "Location":$loc,
         "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-        "::mlir::ArrayRef<int>":$reductionDim),
+        "const ::mlir::SetVector<unsigned> &":$reductionDims),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
       /*args=*/(ins
         "::mlir::OpBuilder &":$b,
         "Location ":$loc,
+        "::mlir::ReductionTilingStrategy":$tilingStrategy,
         "ValueRange":$init,
         "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
         "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-        "::mlir::ArrayRef<int>":$reductionDims),
+        "const ::llvm::SetVector<unsigned> &":$reductionDims),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
         "::mlir::OpBuilder &":$b,
         "Location ":$loc,
         "ValueRange":$partialReduce,
-        "::mlir::ArrayRef<int>":$reductionDim),
+        "const ::mlir::SetVector<unsigned> &":$reductionDims),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
         "unsigned":$resultNumber,
         "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
         "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+        "const ::mlir::SetVector<unsigned> &":$reductionDims,
         "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
-        "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
-        "::mlir::ArrayRef<int>":$reductionDims),
+        "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();

‎mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 6 deletions

@@ -2947,10 +2947,11 @@ void transform::TileReductionUsingForOp::build(
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   MLIRContext *ctx = builder.getContext();
   auto opTy = transform::AnyOpType::get(ctx);
-  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/nullptr,
         /*tile_sizes=*/staticTileSizesAttr);
 }
 
@@ -2966,12 +2967,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
         target->getLoc(),
         "Operation should implement PartialReductionOpInterface");
   }
-  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, partialReductionOp,
-      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
-  if (failed(result))
-    return emitDefaultSilenceableFailure(target);
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterReduction);
+  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
+
+  if (failed(result)) {
+    return emitSilenceableFailure(getLoc(),
+                                  "failed to tile using partial reduction");
+  }
   rewriter.replaceOp(target, result->replacements);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());

‎mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 8 additions & 7 deletions

@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
-                                               unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
                                                unsigned targetSize) {
 
   assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
-                                                    loopRanges[dimension].size);
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
   ContinuousTileSizeSpecification spec;
 
   // Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
   if (redDims.front() >= numThreads.size())
     return b.notifyMatchFailure(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
+  unsigned reductionDim = redDims.front();
+  SetVector<unsigned> reductionDims;
+  reductionDims.insert(reductionDim);
   FailureOr<SmallVector<Value>> maybeInitTensors =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(maybeInitTensors))
     return b.notifyMatchFailure(
         op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   FailureOr<MergeResult> mergeResult =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
   if (failed(mergeResult)) {
     return failure();
   }

‎mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 93 additions & 76 deletions

@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include <optional>
@@ -327,23 +328,48 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// Return an AffineMap for a partial result for the given result number,
-/// assuming the partial tiling strategy is outer-reduction loop +
-/// inner-parallel tile. The returned AffineMap can be used as the replacement
-/// AffineMap for the inner-parallel tile linalg op for the given result number.
-///
-/// The new AffineMap is the old AffineMap with reduction dimensions appended
-/// at end.
-static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
-                                           ArrayRef<int> reductionDims,
-                                           unsigned resultNumber) {
-  AffineMap map =
-      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
-  for (int redPos : reductionDims) {
-    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
-                           map.getNumResults());
+/// Return an AffineMaps to use for the `outs` operands of the linalg op
+/// generated for partial results. The new AffineMap is the AffineMap of the
+/// untiled op with reduction dimensions appended at end in order in which they
+/// were specified during tiling.
+static SmallVector<AffineMap>
+getPartialResultAffineMaps(LinalgOp linalgOp,
+                           const SetVector<unsigned> &reductionDims) {
+  auto partialReductionMaps = llvm::map_to_vector(
+      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
+        for (auto redPos : reductionDims) {
+          map =
+              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                               map.getNumResults());
+        }
+        return map;
+      });
+  return partialReductionMaps;
+}
+
+/// Return the slice of the `initValue` to use as input to the partial reduction
+/// op generated.
+static Operation *getInitSliceForOuterReduction(
+    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (reductionDims.contains(dim)) {
+      initOffsets.push_back(b.getIndexAttr(0));
+    } else {
+      initOffsets.push_back(offsets[dim]);
+    }
+    initSizes.push_back(sizes[dim]);
   }
-  return map;
+  // TODO: Use SubsetExtractOpInterface here once available.
+  auto extractSlice = b.create<tensor::ExtractSliceOp>(
+      loc, initValue, initOffsets, initSizes, initStrides);
+  return extractSlice;
 }
 
 /// External model implementation of PartialReductionInterface for
@@ -354,13 +380,16 @@ struct LinalgOpPartialReductionInterface
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
   FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
       Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+      const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    OpBuilder::InsertionGuard guard(b);
 
+    OpBuilder::InsertionGuard guard(b);
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    SmallVector<AffineMap> partialResultMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // LinalgOp implements TilingInterface.
     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
     SmallVector<OpFoldResult> shape =
@@ -377,8 +406,8 @@ struct LinalgOpPartialReductionInterface
     }
 
     SmallVector<Value> inits;
-    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
-         ++initIdx) {
+    for (auto [initIdx, result, partialMap] :
+         llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -392,16 +421,13 @@ struct LinalgOpPartialReductionInterface
             "Failed to get an identity value for the reduction operation.");
 
       // Append the new partial result dimensions.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
         auto dim = cast<AffineDimExpr>(dimExpr);
         partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
 
-      Type elType =
-          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Type elType = getElementTypeOrSelf(result.getType());
       Value emptyTensor =
           b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
@@ -415,23 +441,25 @@ struct LinalgOpPartialReductionInterface
 
   FailureOr<TilingResult>
   tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                         ReductionTilingStrategy tilingStrategy,
                          ValueRange init, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
-                         ArrayRef<int> reductionDims) const {
+                         const SetVector<unsigned> &reductionDims) const {
+    if (tilingStrategy !=
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      // TODO: Add support for `PartialReductionOuterParallel` strategy.
+      return op->emitOpError("unsupported partial reduction tiling with "
+                             "`PartialReductionOuterParallel` strategy");
    }
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
+
     // Step 1. Extend init maps to have reduction dimension dims, since we
     // are converting them to parallel dimensions.
-    SmallVector<AffineMap> newInitMaps;
-    newInitMaps.reserve(linalgOp.getNumDpsInits());
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      AffineMap newMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
-      newInitMaps.push_back(newMap);
-    }
+    SmallVector<AffineMap> newInitMaps = partialReductionMaps;
 
     // Step 2a: Extract a slice of the input operands.
     SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -443,31 +471,21 @@ struct LinalgOpPartialReductionInterface
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
-    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
-      int64_t initRank = valueMap.getNumResults();
-      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
-      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
-      SmallVector<OpFoldResult> initSizes;
-      for (AffineExpr dimExpr : valueMap.getResults()) {
-        auto dim = cast<AffineDimExpr>(dimExpr);
-        initSizes.push_back(sizes[dim.getPosition()]);
-      }
-      // TODO: Use SubsetExtractOpInterface here once available.
-      auto extractSlice = b.create<tensor::ExtractSliceOp>(
-          loc, valueToTile, initOffset, initSizes, initStride);
-      tiledInits.push_back(extractSlice);
-      generatedSlices.push_back(extractSlice);
+    for (auto [partialReductionMap, valueToTile] :
+         llvm::zip_equal(partialReductionMaps, init)) {
+      Operation *sliceOp =
+          getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
+                                        reductionDims, partialReductionMap);
+      tiledInits.push_back(sliceOp->getResult(0));
+      generatedSlices.push_back(sliceOp);
     }
 
     // Update the indexing maps.
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    // Change the init maps.
-    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
-      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
-      // this with a for range loop when we have it.
-      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
-      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
-      newMaps[mapIdx] = newInitMaps[idx];
+    for (auto [initOperand, newInitMap] :
+         llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
+      int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
+      newMaps[mapIdx] = newInitMap;
     }
 
     // Step 3. Change the reduction dim iterator types.
@@ -477,9 +495,9 @@ struct LinalgOpPartialReductionInterface
       newIteratorTypes[dim] = utils::IteratorType::parallel;
 
     // Step 4. Create the new generic op.
-    auto genericOp =
-        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
-                            tiledInits, newMaps, newIteratorTypes);
+    auto resultTypes = ValueRange(tiledInits).getTypes();
+    auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
+                                         tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -490,23 +508,24 @@ struct LinalgOpPartialReductionInterface
             generatedSlices};
   }
 
-  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
-                                         Location loc, ValueRange partialReduce,
-                                         ArrayRef<int> reductionDims) const {
+  FailureOr<MergeResult>
+  mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                  ValueRange partialReduce,
+                  const SetVector<unsigned> &reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
     // Permute the reduction dims as permuted by the partial result map.
-
-    int64_t numInits = linalgOp.getNumDpsInits();
     SmallVector<Operation *> mergeOperations;
     SmallVector<Value> replacements;
-    for (int idx : llvm::seq(numInits)) {
+    for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
+             linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
+      unsigned initIdx = idx;
      // linalg.reduce's iteration space is the tiled result's iteration space
      // (and not the tiled operation's iteration space). To account for this,
      // permute the reduction dimensions based on the partial result map of the
      // tiled result.
-      AffineMap partialMap =
-          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(partialMap.getResults())) {
@@ -516,15 +535,13 @@ struct LinalgOpPartialReductionInterface
        }
      }
 
-      Value partialResult = partialReduce[idx];
-      Value init = linalgOp.getDpsInits()[idx];
-
      auto reduction = b.create<linalg::ReduceOp>(
          loc, partialResult, init, partialReductionDims,
-          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
+          [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
-            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
+                           combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
            clonedReductionOp->setOperand(0, inputs[0]);
@@ -542,14 +559,14 @@ struct LinalgOpPartialReductionInterface
   LogicalResult getPartialResultTilePosition(
       Operation *op, OpBuilder &b, unsigned resultNumber,
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      const SetVector<unsigned> &reductionDims,
      SmallVector<OpFoldResult> &resultOffsets,
-      SmallVector<OpFoldResult> &resultSizes,
-      ArrayRef<int> reductionDims) const {
+      SmallVector<OpFoldResult> &resultSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> partialReductionMaps =
+        getPartialResultAffineMaps(linalgOp, reductionDims);
 
-    AffineMap partialMap =
-        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
-    for (AffineExpr dimExpr : partialMap.getResults()) {
+    for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
       unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
       resultSizes.push_back(sizes[dim]);
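
As a worked example of the map-extension rule implemented by `getPartialResultAffineMaps` above, the standalone sketch below (plain C++ with dimension positions as integers, not the MLIR `AffineMap` API) appends the tiled reduction dimensions, in the order they were specified, to the untiled result map; with result map `(d0, d1, d2) -> (d0)` and `reduction_dims = [2, 1]` it produces `(d0, d2, d1)`, which is the `#[[INIT_MAP]]` checked in the last test case of the updated test file below.

```
// Standalone sketch of the partial-result map rule: append the tiled
// reduction dimensions, in the order they were specified, to the result map.
#include <cstdio>
#include <vector>

int main() {
  // Untiled result map of the reduction in the tests: (d0, d1, d2) -> (d0).
  std::vector<unsigned> resultMap = {0};
  // Reduction dims as given to the transform op, e.g. reduction_dims = [2, 1].
  std::vector<unsigned> reductionDims = {2, 1};

  for (unsigned redPos : reductionDims)
    resultMap.push_back(redPos); // appended at the end, order preserved

  // Prints "(d0, d2, d1)", matching the "reversed" test's #[[INIT_MAP]].
  std::printf("(");
  for (size_t i = 0; i < resultMap.size(); ++i)
    std::printf("%sd%u", i ? ", " : "", resultMap[i]);
  std::printf(")\n");
  return 0;
}
```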

‎mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 87 additions & 126 deletions

@@ -77,9 +77,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 //===----------------------------------------------------------------------===//
 
 /// Verify the tile size options are set in a consistent manner.
-static LogicalResult
-verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
-                      const scf::SCFTilingOptions &options) {
+static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
+                                   const scf::SCFTilingOptions &options) {
   // Specifying number of threads is only supported on `scf.forall` op.
   if (options.numThreadsComputationFunction &&
       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
@@ -156,7 +155,9 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
 }
 
 /// Checks if any of the tiled loops are not parallel.
-static void checkSafeToTileToForall(TilingInterface op,
+static LogicalResult checkTileSizes(TilingInterface op,
+                                    scf::SCFTilingOptions::LoopType loopType,
+                                    ReductionTilingStrategy reductionStrategy,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     ArrayRef<OpFoldResult> numThreads) {
   auto iterators = op.getLoopIteratorTypes();
@@ -165,28 +166,46 @@ static void checkSafeToTileToForall(TilingInterface op,
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
+  bool isParallelTiling = false, isReductionTiling = false;
   for (auto [index, iterator, tileSize] :
        llvm::enumerate(iterators, tileSizes)) {
-    // If num threads is specified, check that it is greater than one only for
-    // parallel dimensions.
-    if (!numThreads.empty()) {
-      if (std::optional<int64_t> constNumThreads =
-              getConstantIntValue(numThreads[index])) {
-        if (constNumThreads.value() > 1 &&
+    if (!isConstantIntValue(tileSize, 0)) {
+      isParallelTiling |= iterator == utils::IteratorType::parallel;
+      isReductionTiling |= iterator == utils::IteratorType::reduction;
+    }
+
+    if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
+        reductionStrategy == ReductionTilingStrategy::FullReduction) {
+      // If num threads is specified, check that it is greater than one only for
+      // parallel dimensions.
+      if (!numThreads.empty()) {
+        if (std::optional<int64_t> constNumThreads =
+                getConstantIntValue(numThreads[index])) {
+          if (constNumThreads.value() > 1 &&
+              iterator != utils::IteratorType::parallel) {
+            op.emitWarning() << "tiling is not thread safe at axis #" << index;
+          }
+        }
+        continue;
+      }
+
+      if (std::optional<int64_t> constTileSize =
+              getConstantIntValue(tileSize)) {
+        if (constTileSize.value() > 0 &&
            iterator != utils::IteratorType::parallel) {
          op.emitWarning() << "tiling is not thread safe at axis #" << index;
        }
      }
-      continue;
    }
+  }
 
-    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
-      if (constTileSize.value() > 0 &&
-          iterator != utils::IteratorType::parallel) {
-        op.emitWarning() << "tiling is not thread safe at axis #" << index;
-      }
-    }
+  if (isParallelTiling && isReductionTiling &&
+      reductionStrategy != ReductionTilingStrategy::FullReduction) {
+    return op->emitOpError(
+        "combined parallel and reduction tiling is not supported with partial "
+        "reduction tiling strategies");
   }
+  return success();
 }
 
 /// Check if `stride` evenly divides the trip count `size - offset`.
@@ -575,70 +594,41 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
                               const scf::SCFTilingOptions &options) {
   SmallVector<Value> initTensors;
   Location loc = op->getLoc();
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
       return failure();
     return initTensors;
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.generateInitialTensorForPartialReduction(
-        rewriter, loc, tileSizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported"
            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.generateInitialTensorForPartialReduction(
+      rewriter, loc, tileSizes, options.reductionDims);
 }
 
 static FailureOr<TilingResult>
 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
    return op.getTiledImplementation(rewriter, offsets, sizes);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
-                                        offsets, sizes, reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
   }
+  return redOp.tileToPartialReduction(rewriter, op.getLoc(),
+                                      options.reductionStrategy, regionIterArg,
+                                      offsets, sizes, options.reductionDims);
 }
 
 static LogicalResult
@@ -649,70 +639,37 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                       SmallVector<OpFoldResult> &resultSize,
                       const scf::SCFTilingOptions &options) {
 
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
    return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                    resultOffset, resultSize);
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only supported"
-              "for operations implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
-                                              resultOffset, resultSize,
-                                              reductionDims);
   }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only supported"
            "for operations implementing PartialReductionOpInterface");
   }
+  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                            options.reductionDims, resultOffset,
+                                            resultSize);
 }
 
 static FailureOr<MergeResult>
 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                    ValueRange partialResults,
                    const scf::SCFTilingOptions &options) {
-  switch (options.reductionStrategy) {
-  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
-    // No need to merge results for reduction tiling strategy.
-    return MergeResult{{}, partialResults};
-  case scf::SCFTilingOptions::ReductionTilingStrategy::
-      PartialReductionOuterReduction: {
-    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
-    if (!redOp) {
-      return rewriter.notifyMatchFailure(
-          op, "PartialReductionOuterReduction tiling strategy is only "
-              "supported for operations "
-              "implementing PartialReductionOpInterface");
-    }
-    // Get reduction dimensions.
-    // TODO: PartialReductionOpInterface should really query TilingInterface
-    // itself and find reduction dimensions.
-    SmallVector<int> reductionDims;
-    for (auto [idx, iteratorType] :
-         llvm::enumerate(op.getLoopIteratorTypes())) {
-      if (iteratorType == utils::IteratorType::reduction)
-        reductionDims.push_back(idx);
-    }
-    return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
-                                 reductionDims);
-  }
-  default:
-    return rewriter.notifyMatchFailure(op,
-                                       "unhandled reduction tiling strategy");
+  assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction &&
+         "expected merge to be called for only partial reduction cases");
+
+  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!redOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
   }
+  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
+                               options.reductionDims);
 }
 
 /// Append the specified additional `newInitOperands` operands to the
@@ -932,7 +889,7 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
-  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+  if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
     return failure();
   }
 
@@ -949,8 +906,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // Check if it is safe to tile. This is hold over from previous iterations
  // of tile to for-all. Consider dropping it.
-  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
-    checkSafeToTileToForall(op, tileSizes, numThreads);
+  if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
+                            tileSizes, numThreads))) {
+    return failure();
   }
 
   // 3. If there is an interchange specified, permute the iteration domain and
@@ -1073,8 +1031,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                                [](OpResult r) -> Value { return r; });
 
   // For the full reduction case, there is nothing more to do.
-  if (options.reductionStrategy ==
-      scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
+  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return scf::SCFTilingResult{
         tilingResult->tiledOps, initTensors, loops, loopResults,
         tilingResult->generatedSlices, {}};
@@ -1102,9 +1059,13 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   options.setReductionTilingStrategy(
-      scf::SCFTilingOptions::ReductionTilingStrategy::
-          PartialReductionOuterReduction);
+      ReductionTilingStrategy::PartialReductionOuterReduction);
   options.setTileSizes(tileSize);
+  SmallVector<unsigned> reductionDims;
+  for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(index);
+  options.setReductionDims(reductionDims);
   return tileUsingSCF(b, op, options);
 }
 

‎mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 165 additions & 2 deletions

@@ -343,7 +343,6 @@ module attributes {transform.with_named_sequence} {
 module {
   func.func @fail_for_float_neutral(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
     // expected-error @below {{'linalg.generic' op Failed to get an identity value for the reduction operation.}}
-    // expected-note @below {{when applied to this op}}
     %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
     ^bb0(%in: f32, %out: f32):
       %1 = llvm.fmul %in, %in : f32
@@ -355,7 +354,7 @@ module {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{transform.structured.tile_reduction_using_for failed to apply}}
+    // expected-error @below {{failed to tile using partial reduction}}
     %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = transform.structured.tile_reduction_using_for %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
@@ -480,3 +479,167 @@ module attributes {transform.with_named_sequence} {
 // CHECK: }
 // CHECK: linalg.reduce
 // CHECK: return
+
+// -----
+
+// Check that only one of the reduction dimension can be tiled (in this case outer).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_outer(
+      %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %1 = arith.mulf %in, %in_0 : f32
+        %2 = arith.addf %1, %out : f32
+        linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+        transform.structured.tile_reduction_using_for %0 reduction_dims = [1] by tile_sizes = [0, 2]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_outer(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C86:.+]] = arith.constant 86 : index
+// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<4096x2xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+// CHECK:        %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK:          %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+// CHECK:          scf.yield %[[PARTIAL_RESULT]]
+// CHECK:        %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+// CHECK:        return %[[REDUCE]]
+
+// -----
+
+// Check that only one of the reduction dimension can be tiled (in this case inner).
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_inner(
+      %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %1 = arith.mulf %in, %in_0 : f32
+        %2 = arith.addf %1, %out : f32
+        linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+        transform.structured.tile_reduction_using_for %0 reduction_dims = [2] by tile_sizes = [0, 0, 64]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_inner(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+// CHECK:        %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK:          %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:         indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:         iterator_types = ["parallel", "reduction", "parallel"]
+// CHECK-SAME:         outs(%[[ITER_ARG]] :
+// CHECK:          scf.yield %[[PARTIAL_RESULT]]
+// CHECK:        %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1]
+// CHECK:        return %[[REDUCE]]
+
+// -----
+
+// Check that both the reduction dimensions are tiled but the dimensions in the output are swapped.
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_single_of_multiple_reduction_reversed(
+      %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {
+        indexing_maps = [#map, #map1, #map2],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %1 = arith.mulf %in, %in_0 : f32
+        %2 = arith.addf %1, %out : f32
+        linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
+        transform.structured.tile_reduction_using_for %0 reduction_dims = [2, 1] by tile_sizes = [0, 2, 64]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+// CHECK: #[[INIT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: @reduction_tile_single_of_multiple_reduction_reversed(
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<4096xf32>
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG:    %[[C86:.+]] = arith.constant 86 : index
+// CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<4096x64x2xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:       outs(%[[EMPTY]] :
+// CHECK:        %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C86]] step %[[C2]]
+// CHECK-SAME:       iter_args(%[[ITER_ARG:.+]] = %[[FILL]])
+// CHECK:          %[[RESULT0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C64]]
+// CHECK-SAME:         iter_args(%[[ITER_ARG0:.+]] = %[[ITER_ARG]])
+// CHECK:            %[[PARTIAL_RESULT:.+]] = linalg.generic
+// CHECK-SAME:           indexing_maps = [#{{.+}}, #{{.+}}, #[[INIT_MAP]]]
+// CHECK-SAME:           iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:           outs(%[[ITER_ARG0]] :
+// CHECK:            scf.yield %[[PARTIAL_RESULT]]
+// CHECK           scf.yield %[[RESULT0]]
+// CHECK:        %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME:       ins(%[[RESULT]] :
+// CHECK-SAME:       outs(%[[INIT]] :
+// CHECK-SAME:       dimensions = [1, 2]
+// CHECK:        return %[[REDUCE]]
