-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder #89954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.
@llvm/pr-subscribers-mlir-memref Author: Diego Caballero (dcaballe) ChangesThis PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to Full diff: https://github.com/llvm/llvm-project/pull/89954.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..b5578a58468e9c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}
+
+template <typename ReassociativeReshapeLikeOp,
+ SmallVector<OpFoldResult> (*getReshapedSizes)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+ SmallVector<OpFoldResult> (*getReshapedStrides)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/,
+ ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+ ReassociativeReshapeLikeOp reshape) {
+ // Build a plain extract_strided_metadata(memref) from
+ // extract_strided_metadata(reassociative_reshape_like(memref)).
+ Location origLoc = reshape.getLoc();
+ Value source = reshape.getSrc();
+ auto sourceType = cast<MemRefType>(source.getType());
+ unsigned sourceRank = sourceType.getRank();
+
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+ // Collect statically known information.
+ auto [strides, offset] = getStridesAndOffset(sourceType);
+ MemRefType reshapeType = reshape.getResultType();
+ unsigned reshapeRank = reshapeType.getRank();
+
+ OpFoldResult offsetOfr =
+ ShapedType::isDynamic(offset)
+ ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+ : rewriter.getIndexAttr(offset);
+
+ // Get the special case of 0-D out of the way.
+ if (sourceRank == 0) {
+ SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ /*sizes=*/ones, /*strides=*/ones};
+ }
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(reshapeRank);
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(reshapeRank);
+
+ // Compute the reshaped strides and sizes from the base strides and sizes.
+ SmallVector<OpFoldResult> origSizes =
+ getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+ SmallVector<OpFoldResult> origStrides =
+ getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+ unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+ for (; idx != endIdx; ++idx) {
+ SmallVector<OpFoldResult> reshapedSizes =
+ getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+ SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+ reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+ unsigned groupSize = reshapedSizes.size();
+ for (unsigned i = 0; i < groupSize; ++i) {
+ finalSizes.push_back(reshapedSizes[i]);
+ finalStrides.push_back(reshapedStrides[i]);
+ }
+ }
+ assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+ (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+ "We should have visited all the input dimensions");
+ assert(finalSizes.size() == reshapeRank &&
+ "We should have populated all the values");
+
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ finalSizes, finalStrides};
+}
+
/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(reshapeLike(memref))`
/// With
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
- // Build a plain extract_strided_metadata(memref) from
- // extract_strided_metadata(reassociative_reshape_like(memref)).
- Location origLoc = reshape.getLoc();
- Value source = reshape.getSrc();
- auto sourceType = cast<MemRefType>(source.getType());
- unsigned sourceRank = sourceType.getRank();
-
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
- // Collect statically known information.
- auto [strides, offset] = getStridesAndOffset(sourceType);
- MemRefType reshapeType = reshape.getResultType();
- unsigned reshapeRank = reshapeType.getRank();
-
- OpFoldResult offsetOfr =
- ShapedType::isDynamic(offset)
- ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
- : rewriter.getIndexAttr(offset);
-
- // Get the special case of 0-D out of the way.
- if (sourceRank == 0) {
- SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, /*sizes=*/ones, /*strides=*/ones);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
- return success();
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+ getReshapedSizes, getReshapedStrides>(
+ rewriter, reshape);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(reshape,
+ "failed to resolve reshape metadata");
}
- SmallVector<OpFoldResult> finalSizes;
- finalSizes.reserve(reshapeRank);
- SmallVector<OpFoldResult> finalStrides;
- finalStrides.reserve(reshapeRank);
-
- // Compute the reshaped strides and sizes from the base strides and sizes.
- SmallVector<OpFoldResult> origSizes =
- getAsOpFoldResult(newExtractStridedMetadata.getSizes());
- SmallVector<OpFoldResult> origStrides =
- getAsOpFoldResult(newExtractStridedMetadata.getStrides());
- unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
- for (; idx != endIdx; ++idx) {
- SmallVector<OpFoldResult> reshapedSizes =
- getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
- SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
- reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
- unsigned groupSize = reshapedSizes.size();
- for (unsigned i = 0; i < groupSize; ++i) {
- finalSizes.push_back(reshapedSizes[i]);
- finalStrides.push_back(reshapedStrides[i]);
- }
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ reshape, reshape.getType(), stridedMetadata->basePtr,
+ stridedMetadata->offset, stridedMetadata->sizes,
+ stridedMetadata->strides);
+ return success();
+ }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+/// extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \verbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+ : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return failure();
+
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
+ getCollapsedStride>(rewriter,
+ collapseShapeOp);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to resolve metadata in terms of source collapse_shape op");
}
- assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
- (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
- "We should have visited all the input dimensions");
- assert(finalSizes.size() == reshapeRank &&
- "We should have populated all the values");
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, finalSizes, finalStrides);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+ Location loc = collapseShapeOp.getLoc();
+ SmallVector<Value> results;
+ results.push_back(stridedMetadata->basePtr);
+ results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->offset));
+ results.append(
+ getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+ results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->strides));
+ rewriter.replaceOp(op, results);
return success();
}
};
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to Full diff: https://github.com/llvm/llvm-project/pull/89954.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..b5578a58468e9c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}
+
+template <typename ReassociativeReshapeLikeOp,
+ SmallVector<OpFoldResult> (*getReshapedSizes)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+ SmallVector<OpFoldResult> (*getReshapedStrides)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/,
+ ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+ ReassociativeReshapeLikeOp reshape) {
+ // Build a plain extract_strided_metadata(memref) from
+ // extract_strided_metadata(reassociative_reshape_like(memref)).
+ Location origLoc = reshape.getLoc();
+ Value source = reshape.getSrc();
+ auto sourceType = cast<MemRefType>(source.getType());
+ unsigned sourceRank = sourceType.getRank();
+
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+ // Collect statically known information.
+ auto [strides, offset] = getStridesAndOffset(sourceType);
+ MemRefType reshapeType = reshape.getResultType();
+ unsigned reshapeRank = reshapeType.getRank();
+
+ OpFoldResult offsetOfr =
+ ShapedType::isDynamic(offset)
+ ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+ : rewriter.getIndexAttr(offset);
+
+ // Get the special case of 0-D out of the way.
+ if (sourceRank == 0) {
+ SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ /*sizes=*/ones, /*strides=*/ones};
+ }
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(reshapeRank);
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(reshapeRank);
+
+ // Compute the reshaped strides and sizes from the base strides and sizes.
+ SmallVector<OpFoldResult> origSizes =
+ getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+ SmallVector<OpFoldResult> origStrides =
+ getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+ unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+ for (; idx != endIdx; ++idx) {
+ SmallVector<OpFoldResult> reshapedSizes =
+ getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+ SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+ reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+ unsigned groupSize = reshapedSizes.size();
+ for (unsigned i = 0; i < groupSize; ++i) {
+ finalSizes.push_back(reshapedSizes[i]);
+ finalStrides.push_back(reshapedStrides[i]);
+ }
+ }
+ assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+ (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+ "We should have visited all the input dimensions");
+ assert(finalSizes.size() == reshapeRank &&
+ "We should have populated all the values");
+
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ finalSizes, finalStrides};
+}
+
/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(reshapeLike(memref))`
/// With
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
- // Build a plain extract_strided_metadata(memref) from
- // extract_strided_metadata(reassociative_reshape_like(memref)).
- Location origLoc = reshape.getLoc();
- Value source = reshape.getSrc();
- auto sourceType = cast<MemRefType>(source.getType());
- unsigned sourceRank = sourceType.getRank();
-
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
- // Collect statically known information.
- auto [strides, offset] = getStridesAndOffset(sourceType);
- MemRefType reshapeType = reshape.getResultType();
- unsigned reshapeRank = reshapeType.getRank();
-
- OpFoldResult offsetOfr =
- ShapedType::isDynamic(offset)
- ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
- : rewriter.getIndexAttr(offset);
-
- // Get the special case of 0-D out of the way.
- if (sourceRank == 0) {
- SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, /*sizes=*/ones, /*strides=*/ones);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
- return success();
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+ getReshapedSizes, getReshapedStrides>(
+ rewriter, reshape);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(reshape,
+ "failed to resolve reshape metadata");
}
- SmallVector<OpFoldResult> finalSizes;
- finalSizes.reserve(reshapeRank);
- SmallVector<OpFoldResult> finalStrides;
- finalStrides.reserve(reshapeRank);
-
- // Compute the reshaped strides and sizes from the base strides and sizes.
- SmallVector<OpFoldResult> origSizes =
- getAsOpFoldResult(newExtractStridedMetadata.getSizes());
- SmallVector<OpFoldResult> origStrides =
- getAsOpFoldResult(newExtractStridedMetadata.getStrides());
- unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
- for (; idx != endIdx; ++idx) {
- SmallVector<OpFoldResult> reshapedSizes =
- getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
- SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
- reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
- unsigned groupSize = reshapedSizes.size();
- for (unsigned i = 0; i < groupSize; ++i) {
- finalSizes.push_back(reshapedSizes[i]);
- finalStrides.push_back(reshapedStrides[i]);
- }
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ reshape, reshape.getType(), stridedMetadata->basePtr,
+ stridedMetadata->offset, stridedMetadata->sizes,
+ stridedMetadata->strides);
+ return success();
+ }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+/// extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \verbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+ : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return failure();
+
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
+ getCollapsedStride>(rewriter,
+ collapseShapeOp);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to resolve metadata in terms of source collapse_shape op");
}
- assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
- (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
- "We should have visited all the input dimensions");
- assert(finalSizes.size() == reshapeRank &&
- "We should have populated all the values");
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, finalSizes, finalStrides);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+ Location loc = collapseShapeOp.getLoc();
+ SmallVector<Value> results;
+ results.push_back(stridedMetadata->basePtr);
+ results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->offset));
+ results.append(
+ getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+ results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->strides));
+ rewriter.replaceOp(op, results);
return success();
}
};
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
|
Interesting enough, I can't find where |
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, | |||
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap, | |||
groupStrides)}; | |||
} | |||
|
|||
template <typename ReassociativeReshapeLikeOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you move the comment explaining what this does?
I guess you meant I see two paths forward
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo:
- add a test
- move some comments on the refactored helper function
SmallVector<OpFoldResult> (*getReshapedSizes)( | ||
ReassociativeReshapeLikeOp, OpBuilder &, | ||
ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/), | ||
SmallVector<OpFoldResult> (*getReshapedStrides)( | ||
ReassociativeReshapeLikeOp, OpBuilder &, | ||
ArrayRef<OpFoldResult> /*origSizes*/, | ||
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A fly-by nit: is there an observable benefit to using these template arguments as opposed to passing in function_ref
callbacks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably not. I mostly followed the style of the caller
Thanks for the suggesting. Ok, let me try 1. |
- Add test - Add doc - Use function_ref
…89954) This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.
This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to
ExtractStridedMetadataOpSubviewFolder
, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.