diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 96eb7cfd2db69..585c5b7381421 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -550,6 +550,89 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap, groupStrides)}; } + +/// From `reshape_like(memref, subSizes, subStrides))` compute +/// +/// \verbatim +/// baseBuffer, baseOffset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// strides#i = baseStrides#i * subStrides#i +/// sizes = subSizes +/// \endverbatim +/// +/// and return {baseBuffer, baseOffset, sizes, strides} +template +static FailureOr resolveReshapeStridedMetadata( + RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape, + function_ref( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, unsigned /*groupId*/)> + getReshapedSizes, + function_ref( + ReassociativeReshapeLikeOp, OpBuilder &, + ArrayRef /*origSizes*/, + ArrayRef /*origStrides*/, unsigned /*groupId*/)> + getReshapedStrides) { + // Build a plain extract_strided_metadata(memref) from + // extract_strided_metadata(reassociative_reshape_like(memref)). + Location origLoc = reshape.getLoc(); + Value source = reshape.getSrc(); + auto sourceType = cast(source.getType()); + unsigned sourceRank = sourceType.getRank(); + + auto newExtractStridedMetadata = + rewriter.create(origLoc, source); + + // Collect statically known information. + auto [strides, offset] = getStridesAndOffset(sourceType); + MemRefType reshapeType = reshape.getResultType(); + unsigned reshapeRank = reshapeType.getRank(); + + OpFoldResult offsetOfr = + ShapedType::isDynamic(offset) + ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) + : rewriter.getIndexAttr(offset); + + // Get the special case of 0-D out of the way. + if (sourceRank == 0) { + SmallVector ones(reshapeRank, rewriter.getIndexAttr(1)); + return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, + /*sizes=*/ones, /*strides=*/ones}; + } + + SmallVector finalSizes; + finalSizes.reserve(reshapeRank); + SmallVector finalStrides; + finalStrides.reserve(reshapeRank); + + // Compute the reshaped strides and sizes from the base strides and sizes. + SmallVector origSizes = + getAsOpFoldResult(newExtractStridedMetadata.getSizes()); + SmallVector origStrides = + getAsOpFoldResult(newExtractStridedMetadata.getStrides()); + unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); + for (; idx != endIdx; ++idx) { + SmallVector reshapedSizes = + getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); + SmallVector reshapedStrides = getReshapedStrides( + reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); + + unsigned groupSize = reshapedSizes.size(); + for (unsigned i = 0; i < groupSize; ++i) { + finalSizes.push_back(reshapedSizes[i]); + finalStrides.push_back(reshapedStrides[i]); + } + } + assert(((isa(reshape) && idx == sourceRank) || + (isa(reshape) && idx == reshapeRank)) && + "We should have visited all the input dimensions"); + assert(finalSizes.size() == reshapeRank && + "We should have populated all the values"); + + return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, + finalSizes, finalStrides}; +} + /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(reshapeLike(memref))` /// With @@ -580,68 +663,65 @@ struct ReshapeFolder : public OpRewritePattern { LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, PatternRewriter &rewriter) const override { - // Build a plain extract_strided_metadata(memref) from - // extract_strided_metadata(reassociative_reshape_like(memref)). - Location origLoc = reshape.getLoc(); - Value source = reshape.getSrc(); - auto sourceType = cast(source.getType()); - unsigned sourceRank = sourceType.getRank(); - - auto newExtractStridedMetadata = - rewriter.create(origLoc, source); - - // Collect statically known information. - auto [strides, offset] = getStridesAndOffset(sourceType); - MemRefType reshapeType = reshape.getResultType(); - unsigned reshapeRank = reshapeType.getRank(); - - OpFoldResult offsetOfr = - ShapedType::isDynamic(offset) - ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) - : rewriter.getIndexAttr(offset); - - // Get the special case of 0-D out of the way. - if (sourceRank == 0) { - SmallVector ones(reshapeRank, rewriter.getIndexAttr(1)); - auto memrefDesc = rewriter.create( - origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), - offsetOfr, /*sizes=*/ones, /*strides=*/ones); - rewriter.replaceOp(reshape, memrefDesc.getResult()); - return success(); + FailureOr stridedMetadata = + resolveReshapeStridedMetadata( + rewriter, reshape, getReshapedSizes, getReshapedStrides); + if (failed(stridedMetadata)) { + return rewriter.notifyMatchFailure(reshape, + "failed to resolve reshape metadata"); } - SmallVector finalSizes; - finalSizes.reserve(reshapeRank); - SmallVector finalStrides; - finalStrides.reserve(reshapeRank); - - // Compute the reshaped strides and sizes from the base strides and sizes. - SmallVector origSizes = - getAsOpFoldResult(newExtractStridedMetadata.getSizes()); - SmallVector origStrides = - getAsOpFoldResult(newExtractStridedMetadata.getStrides()); - unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); - for (; idx != endIdx; ++idx) { - SmallVector reshapedSizes = - getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); - SmallVector reshapedStrides = getReshapedStrides( - reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); - - unsigned groupSize = reshapedSizes.size(); - for (unsigned i = 0; i < groupSize; ++i) { - finalSizes.push_back(reshapedSizes[i]); - finalStrides.push_back(reshapedStrides[i]); - } + rewriter.replaceOpWithNewOp( + reshape, reshape.getType(), stridedMetadata->basePtr, + stridedMetadata->offset, stridedMetadata->sizes, + stridedMetadata->strides); + return success(); + } +}; + +/// Pattern to replace `extract_strided_metadata(collapse_shape)` +/// With +/// +/// \verbatim +/// baseBuffer, baseOffset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// strides#i = baseStrides#i * subSizes#i +/// offset = baseOffset + sum(subOffset#i * baseStrides#i) +/// sizes = subSizes +/// \verbatim +/// +/// with `baseBuffer`, `offset`, `sizes` and `strides` being +/// the replacements for the original `extract_strided_metadata`. +struct ExtractStridedMetadataOpCollapseShapeFolder + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + PatternRewriter &rewriter) const override { + auto collapseShapeOp = + op.getSource().getDefiningOp(); + if (!collapseShapeOp) + return failure(); + + FailureOr stridedMetadata = + resolveReshapeStridedMetadata( + rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride); + if (failed(stridedMetadata)) { + return rewriter.notifyMatchFailure( + op, + "failed to resolve metadata in terms of source collapse_shape op"); } - assert(((isa(reshape) && idx == sourceRank) || - (isa(reshape) && idx == reshapeRank)) && - "We should have visited all the input dimensions"); - assert(finalSizes.size() == reshapeRank && - "We should have populated all the values"); - auto memrefDesc = rewriter.create( - origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), - offsetOfr, finalSizes, finalStrides); - rewriter.replaceOp(reshape, memrefDesc.getResult()); + + Location loc = collapseShapeOp.getLoc(); + SmallVector results; + results.push_back(stridedMetadata->basePtr); + results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->offset)); + results.append( + getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); + results.append(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->strides)); + rewriter.replaceOp(op, results); return success(); } }; @@ -1018,9 +1098,11 @@ void memref::populateExpandStridedMetadataPatterns( getCollapsedStride>, ExtractStridedMetadataOpAllocFolder, ExtractStridedMetadataOpAllocFolder, + ExtractStridedMetadataOpCollapseShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, + ExtractStridedMetadataOpSubviewFolder, ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); @@ -1030,6 +1112,7 @@ void memref::populateResolveExtractStridedMetadataPatterns( RewritePatternSet &patterns) { patterns.add, ExtractStridedMetadataOpAllocFolder, + ExtractStridedMetadataOpCollapseShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 28b7004300594..0705b30ca45d8 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref, index,index,index) %sizes, %strides : memref, index, index, index -} \ No newline at end of file +} + +// ----- + +func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>) + -> (memref, index, index, index) { + + %collapse = memref.collapse_shape %base[[0, 1]] : + memref<5x4xf32> into memref<20xf32> + + %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse : + memref<20xf32> -> memref, index, index, index + + return %base_buffer, %offset, %size, %stride : + memref, index, index, index +} + +// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape +// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index +// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index +// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata +// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref, index, index, index