From aaf3a95d0e6c52f7ced2d054ca35adb52b15a099 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 24 Apr 2024 17:07:46 +0000 Subject: [PATCH 1/4] [mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op. --- .../Transforms/ExpandStridedMetadata.cpp | 189 ++++++++++++------ 1 file changed, 130 insertions(+), 59 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 96eb7cfd2db69..b5578a58468e9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap, groupStrides)}; } + +template (*getReshapedSizes)( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, unsigned /*groupId*/), + SmallVector (*getReshapedStrides)( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, + ArrayRef /*origStrides*/, unsigned /*groupId*/)> +static FailureOr +resolveReshapeStridedMetadata(RewriterBase &rewriter, + ReassociativeReshapeLikeOp reshape) { + // Build a plain extract_strided_metadata(memref) from + // extract_strided_metadata(reassociative_reshape_like(memref)). + Location origLoc = reshape.getLoc(); + Value source = reshape.getSrc(); + auto sourceType = cast(source.getType()); + unsigned sourceRank = sourceType.getRank(); + + auto newExtractStridedMetadata = + rewriter.create(origLoc, source); + + // Collect statically known information. + auto [strides, offset] = getStridesAndOffset(sourceType); + MemRefType reshapeType = reshape.getResultType(); + unsigned reshapeRank = reshapeType.getRank(); + + OpFoldResult offsetOfr = + ShapedType::isDynamic(offset) + ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) + : rewriter.getIndexAttr(offset); + + // Get the special case of 0-D out of the way. + if (sourceRank == 0) { + SmallVector ones(reshapeRank, rewriter.getIndexAttr(1)); + return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, + /*sizes=*/ones, /*strides=*/ones}; + } + + SmallVector finalSizes; + finalSizes.reserve(reshapeRank); + SmallVector finalStrides; + finalStrides.reserve(reshapeRank); + + // Compute the reshaped strides and sizes from the base strides and sizes. + SmallVector origSizes = + getAsOpFoldResult(newExtractStridedMetadata.getSizes()); + SmallVector origStrides = + getAsOpFoldResult(newExtractStridedMetadata.getStrides()); + unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); + for (; idx != endIdx; ++idx) { + SmallVector reshapedSizes = + getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); + SmallVector reshapedStrides = getReshapedStrides( + reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); + + unsigned groupSize = reshapedSizes.size(); + for (unsigned i = 0; i < groupSize; ++i) { + finalSizes.push_back(reshapedSizes[i]); + finalStrides.push_back(reshapedStrides[i]); + } + } + assert(((isa(reshape) && idx == sourceRank) || + (isa(reshape) && idx == reshapeRank)) && + "We should have visited all the input dimensions"); + assert(finalSizes.size() == reshapeRank && + "We should have populated all the values"); + + return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, + finalSizes, finalStrides}; +} + /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(reshapeLike(memref))` /// With @@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern { LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, PatternRewriter &rewriter) const override { - // Build a plain extract_strided_metadata(memref) from - // extract_strided_metadata(reassociative_reshape_like(memref)). - Location origLoc = reshape.getLoc(); - Value source = reshape.getSrc(); - auto sourceType = cast(source.getType()); - unsigned sourceRank = sourceType.getRank(); - - auto newExtractStridedMetadata = - rewriter.create(origLoc, source); - - // Collect statically known information. - auto [strides, offset] = getStridesAndOffset(sourceType); - MemRefType reshapeType = reshape.getResultType(); - unsigned reshapeRank = reshapeType.getRank(); - - OpFoldResult offsetOfr = - ShapedType::isDynamic(offset) - ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) - : rewriter.getIndexAttr(offset); - - // Get the special case of 0-D out of the way. - if (sourceRank == 0) { - SmallVector ones(reshapeRank, rewriter.getIndexAttr(1)); - auto memrefDesc = rewriter.create( - origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), - offsetOfr, /*sizes=*/ones, /*strides=*/ones); - rewriter.replaceOp(reshape, memrefDesc.getResult()); - return success(); + FailureOr stridedMetadata = + resolveReshapeStridedMetadata( + rewriter, reshape); + if (failed(stridedMetadata)) { + return rewriter.notifyMatchFailure(reshape, + "failed to resolve reshape metadata"); } - SmallVector finalSizes; - finalSizes.reserve(reshapeRank); - SmallVector finalStrides; - finalStrides.reserve(reshapeRank); - - // Compute the reshaped strides and sizes from the base strides and sizes. - SmallVector origSizes = - getAsOpFoldResult(newExtractStridedMetadata.getSizes()); - SmallVector origStrides = - getAsOpFoldResult(newExtractStridedMetadata.getStrides()); - unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); - for (; idx != endIdx; ++idx) { - SmallVector reshapedSizes = - getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); - SmallVector reshapedStrides = getReshapedStrides( - reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); - - unsigned groupSize = reshapedSizes.size(); - for (unsigned i = 0; i < groupSize; ++i) { - finalSizes.push_back(reshapedSizes[i]); - finalStrides.push_back(reshapedStrides[i]); - } + rewriter.replaceOpWithNewOp( + reshape, reshape.getType(), stridedMetadata->basePtr, + stridedMetadata->offset, stridedMetadata->sizes, + stridedMetadata->strides); + return success(); + } +}; + +/// Pattern to replace `extract_strided_metadata(collapse_shape)` +/// With +/// +/// \verbatim +/// baseBuffer, baseOffset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// strides#i = baseStrides#i * subSizes#i +/// offset = baseOffset + sum(subOffset#i * baseStrides#i) +/// sizes = subSizes +/// \verbatim +/// +/// with `baseBuffer`, `offset`, `sizes` and `strides` being +/// the replacements for the original `extract_strided_metadata`. +struct ExtractStridedMetadataOpCollapseShapeFolder + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + PatternRewriter &rewriter) const override { + auto collapseShapeOp = + op.getSource().getDefiningOp(); + if (!collapseShapeOp) + return failure(); + + FailureOr stridedMetadata = + resolveReshapeStridedMetadata(rewriter, + collapseShapeOp); + if (failed(stridedMetadata)) { + return rewriter.notifyMatchFailure( + op, "failed to resolve metadata in terms of source collapse_shape op"); } - assert(((isa(reshape) && idx == sourceRank) || - (isa(reshape) && idx == reshapeRank)) && - "We should have visited all the input dimensions"); - assert(finalSizes.size() == reshapeRank && - "We should have populated all the values"); - auto memrefDesc = rewriter.create( - origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), - offsetOfr, finalSizes, finalStrides); - rewriter.replaceOp(reshape, memrefDesc.getResult()); + + Location loc = collapseShapeOp.getLoc(); + SmallVector results; + results.push_back(stridedMetadata->basePtr); + results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->offset)); + results.append( + getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); + results.append(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->strides)); + rewriter.replaceOp(op, results); return success(); } }; @@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns( RewritePatternSet &patterns) { patterns.add, ExtractStridedMetadataOpAllocFolder, + ExtractStridedMetadataOpCollapseShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, From 19c46b519f9a64e92ad9638d04e2a0ba528b96bf Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 26 Apr 2024 09:55:40 +0000 Subject: [PATCH 2/4] Review feedback - Add test - Add doc - Use function_ref --- .../Transforms/ExpandStridedMetadata.cpp | 38 +++++++++++++------ .../MemRef/expand-strided-metadata.mlir | 24 +++++++++++- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index b5578a58468e9..479646756cb5d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -551,17 +551,28 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, groupStrides)}; } -template (*getReshapedSizes)( - ReassociativeReshapeLikeOp, OpBuilder &, - ArrayRef /*origSizes*/, unsigned /*groupId*/), - SmallVector (*getReshapedStrides)( - ReassociativeReshapeLikeOp, OpBuilder &, - ArrayRef /*origSizes*/, - ArrayRef /*origStrides*/, unsigned /*groupId*/)> -static FailureOr -resolveReshapeStridedMetadata(RewriterBase &rewriter, - ReassociativeReshapeLikeOp reshape) { +/// From `reshape_like(memref, subSizes, subStrides))` compute +/// +/// \verbatim +/// baseBuffer, baseOffset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// strides#i = baseStrides#i * subStrides#i +/// offset = baseOffset + sum(subOffset#i * baseStrides#i) +/// sizes = subSizes +/// \endverbatim +/// +/// and return {baseBuffer, offset, sizes, strides} +template +static FailureOr resolveReshapeStridedMetadata( + RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape, + function_ref( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, unsigned /*groupId*/)> + getReshapedSizes, + function_ref( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, unsigned /*groupId*/)> + getReshapedStrides) { // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(reassociative_reshape_like(memref)). Location origLoc = reshape.getLoc(); @@ -699,7 +710,8 @@ struct ExtractStridedMetadataOpCollapseShapeFolder collapseShapeOp); if (failed(stridedMetadata)) { return rewriter.notifyMatchFailure( - op, "failed to resolve metadata in terms of source collapse_shape op"); + op, + "failed to resolve metadata in terms of source collapse_shape op"); } Location loc = collapseShapeOp.getLoc(); @@ -1088,9 +1100,11 @@ void memref::populateExpandStridedMetadataPatterns( getCollapsedStride>, ExtractStridedMetadataOpAllocFolder, ExtractStridedMetadataOpAllocFolder, + ExtractStridedMetadataOpCollapseShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, + ExtractStridedMetadataOpSubviewFolder, ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 28b7004300594..0705b30ca45d8 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref, index,index,index) %sizes, %strides : memref, index, index, index -} \ No newline at end of file +} + +// ----- + +func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>) + -> (memref, index, index, index) { + + %collapse = memref.collapse_shape %base[[0, 1]] : + memref<5x4xf32> into memref<20xf32> + + %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse : + memref<20xf32> -> memref, index, index, index + + return %base_buffer, %offset, %size, %stride : + memref, index, index, index +} + +// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape +// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index +// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index +// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata +// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref, index, index, index From 3500630fc7a90164dd5180e4696fbd901b083545 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 26 Apr 2024 10:00:15 +0000 Subject: [PATCH 3/4] Fix doc --- mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 479646756cb5d..999b50e25ca8f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -557,11 +557,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// strides#i = baseStrides#i * subStrides#i -/// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// \endverbatim /// -/// and return {baseBuffer, offset, sizes, strides} +/// and return {baseBuffer, baseOffset, sizes, strides} template static FailureOr resolveReshapeStridedMetadata( RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape, From e3bd384523e6498acd437882f2f87db8b7db3604 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 26 Apr 2024 10:19:44 +0000 Subject: [PATCH 4/4] Final fixups! --- .../MemRef/Transforms/ExpandStridedMetadata.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 999b50e25ca8f..585c5b7381421 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -570,7 +570,8 @@ static FailureOr resolveReshapeStridedMetadata( getReshapedSizes, function_ref( ReassociativeReshapeLikeOp, OpBuilder &, - ArrayRef /*origSizes*/, unsigned /*groupId*/)> + ArrayRef /*origSizes*/, + ArrayRef /*origStrides*/, unsigned /*groupId*/)> getReshapedStrides) { // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(reassociative_reshape_like(memref)). @@ -663,9 +664,8 @@ struct ReshapeFolder : public OpRewritePattern { LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, PatternRewriter &rewriter) const override { FailureOr stridedMetadata = - resolveReshapeStridedMetadata( - rewriter, reshape); + resolveReshapeStridedMetadata( + rewriter, reshape, getReshapedSizes, getReshapedStrides); if (failed(stridedMetadata)) { return rewriter.notifyMatchFailure(reshape, "failed to resolve reshape metadata"); @@ -704,9 +704,8 @@ struct ExtractStridedMetadataOpCollapseShapeFolder return failure(); FailureOr stridedMetadata = - resolveReshapeStridedMetadata(rewriter, - collapseShapeOp); + resolveReshapeStridedMetadata( + rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride); if (failed(stridedMetadata)) { return rewriter.notifyMatchFailure( op,