Skip to content

[mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder #89954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 26, 2024

Conversation

dcaballe
Copy link
Contributor

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to ExtractStridedMetadataOpSubviewFolder, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.

This PR adds a new pattern to the set of patterns used to resolve the
offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`,
the new pattern resolves strided_metadata(collapse_shape) directly,
without introduce a reshape_cast op.
@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-memref

Author: Diego Caballero (dcaballe)

Changes

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to ExtractStridedMetadataOpSubviewFolder, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.


Full diff: https://github.com/llvm/llvm-project/pull/89954.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+130-59)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..b5578a58468e9c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                       groupStrides)};
 }
+
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+                              ReassociativeReshapeLikeOp reshape) {
+  // Build a plain extract_strided_metadata(memref) from
+  // extract_strided_metadata(reassociative_reshape_like(memref)).
+  Location origLoc = reshape.getLoc();
+  Value source = reshape.getSrc();
+  auto sourceType = cast<MemRefType>(source.getType());
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  // Collect statically known information.
+  auto [strides, offset] = getStridesAndOffset(sourceType);
+  MemRefType reshapeType = reshape.getResultType();
+  unsigned reshapeRank = reshapeType.getRank();
+
+  OpFoldResult offsetOfr =
+      ShapedType::isDynamic(offset)
+          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+          : rewriter.getIndexAttr(offset);
+
+  // Get the special case of 0-D out of the way.
+  if (sourceRank == 0) {
+    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                           /*sizes=*/ones, /*strides=*/ones};
+  }
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(reshapeRank);
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(reshapeRank);
+
+  // Compute the reshaped strides and sizes from the base strides and sizes.
+  SmallVector<OpFoldResult> origSizes =
+      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+  SmallVector<OpFoldResult> origStrides =
+      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+  for (; idx != endIdx; ++idx) {
+    SmallVector<OpFoldResult> reshapedSizes =
+        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+    unsigned groupSize = reshapedSizes.size();
+    for (unsigned i = 0; i < groupSize; ++i) {
+      finalSizes.push_back(reshapedSizes[i]);
+      finalStrides.push_back(reshapedStrides[i]);
+    }
+  }
+  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+         "We should have visited all the input dimensions");
+  assert(finalSizes.size() == reshapeRank &&
+         "We should have populated all the values");
+
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `baseBuffer, offset, sizes, strides =
 ///              extract_strided_metadata(reshapeLike(memref))`
 /// With
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
 
   LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(reassociative_reshape_like(memref)).
-    Location origLoc = reshape.getLoc();
-    Value source = reshape.getSrc();
-    auto sourceType = cast<MemRefType>(source.getType());
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    // Collect statically known information.
-    auto [strides, offset] = getStridesAndOffset(sourceType);
-    MemRefType reshapeType = reshape.getResultType();
-    unsigned reshapeRank = reshapeType.getRank();
-
-    OpFoldResult offsetOfr =
-        ShapedType::isDynamic(offset)
-            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-            : rewriter.getIndexAttr(offset);
-
-    // Get the special case of 0-D out of the way.
-    if (sourceRank == 0) {
-      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
-      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
-      rewriter.replaceOp(reshape, memrefDesc.getResult());
-      return success();
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+                                      getReshapedSizes, getReshapedStrides>(
+            rewriter, reshape);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(reshape,
+                                         "failed to resolve reshape metadata");
     }
 
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(reshapeRank);
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(reshapeRank);
-
-    // Compute the reshaped strides and sizes from the base strides and sizes.
-    SmallVector<OpFoldResult> origSizes =
-        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
-    SmallVector<OpFoldResult> origStrides =
-        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
-    for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> reshapedSizes =
-          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
-          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
-      unsigned groupSize = reshapedSizes.size();
-      for (unsigned i = 0; i < groupSize; ++i) {
-        finalSizes.push_back(reshapedSizes[i]);
-        finalStrides.push_back(reshapedStrides[i]);
-      }
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        reshape, reshape.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \verbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
+                                      getCollapsedStride>(rewriter,
+                                                          collapseShapeOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source collapse_shape op");
     }
-    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
-            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
-           "We should have visited all the input dimensions");
-    assert(finalSizes.size() == reshapeRank &&
-           "We should have populated all the values");
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-        offsetOfr, finalSizes, finalStrides);
-    rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+    Location loc = collapseShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpCollapseShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to ExtractStridedMetadataOpSubviewFolder, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.


Full diff: https://github.com/llvm/llvm-project/pull/89954.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+130-59)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..b5578a58468e9c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                       groupStrides)};
 }
+
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+                              ReassociativeReshapeLikeOp reshape) {
+  // Build a plain extract_strided_metadata(memref) from
+  // extract_strided_metadata(reassociative_reshape_like(memref)).
+  Location origLoc = reshape.getLoc();
+  Value source = reshape.getSrc();
+  auto sourceType = cast<MemRefType>(source.getType());
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  // Collect statically known information.
+  auto [strides, offset] = getStridesAndOffset(sourceType);
+  MemRefType reshapeType = reshape.getResultType();
+  unsigned reshapeRank = reshapeType.getRank();
+
+  OpFoldResult offsetOfr =
+      ShapedType::isDynamic(offset)
+          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+          : rewriter.getIndexAttr(offset);
+
+  // Get the special case of 0-D out of the way.
+  if (sourceRank == 0) {
+    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                           /*sizes=*/ones, /*strides=*/ones};
+  }
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(reshapeRank);
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(reshapeRank);
+
+  // Compute the reshaped strides and sizes from the base strides and sizes.
+  SmallVector<OpFoldResult> origSizes =
+      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+  SmallVector<OpFoldResult> origStrides =
+      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+  for (; idx != endIdx; ++idx) {
+    SmallVector<OpFoldResult> reshapedSizes =
+        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+    unsigned groupSize = reshapedSizes.size();
+    for (unsigned i = 0; i < groupSize; ++i) {
+      finalSizes.push_back(reshapedSizes[i]);
+      finalStrides.push_back(reshapedStrides[i]);
+    }
+  }
+  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+         "We should have visited all the input dimensions");
+  assert(finalSizes.size() == reshapeRank &&
+         "We should have populated all the values");
+
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `baseBuffer, offset, sizes, strides =
 ///              extract_strided_metadata(reshapeLike(memref))`
 /// With
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
 
   LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(reassociative_reshape_like(memref)).
-    Location origLoc = reshape.getLoc();
-    Value source = reshape.getSrc();
-    auto sourceType = cast<MemRefType>(source.getType());
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    // Collect statically known information.
-    auto [strides, offset] = getStridesAndOffset(sourceType);
-    MemRefType reshapeType = reshape.getResultType();
-    unsigned reshapeRank = reshapeType.getRank();
-
-    OpFoldResult offsetOfr =
-        ShapedType::isDynamic(offset)
-            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-            : rewriter.getIndexAttr(offset);
-
-    // Get the special case of 0-D out of the way.
-    if (sourceRank == 0) {
-      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
-      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
-      rewriter.replaceOp(reshape, memrefDesc.getResult());
-      return success();
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+                                      getReshapedSizes, getReshapedStrides>(
+            rewriter, reshape);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(reshape,
+                                         "failed to resolve reshape metadata");
     }
 
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(reshapeRank);
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(reshapeRank);
-
-    // Compute the reshaped strides and sizes from the base strides and sizes.
-    SmallVector<OpFoldResult> origSizes =
-        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
-    SmallVector<OpFoldResult> origStrides =
-        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
-    for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> reshapedSizes =
-          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
-          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
-      unsigned groupSize = reshapedSizes.size();
-      for (unsigned i = 0; i < groupSize; ++i) {
-        finalSizes.push_back(reshapedSizes[i]);
-        finalStrides.push_back(reshapedStrides[i]);
-      }
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        reshape, reshape.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \verbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
+                                      getCollapsedStride>(rewriter,
+                                                          collapseShapeOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source collapse_shape op");
     }
-    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
-            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
-           "We should have visited all the input dimensions");
-    assert(finalSizes.size() == reshapeRank &&
-           "We should have populated all the values");
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-        offsetOfr, finalSizes, finalStrides);
-    rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+    Location loc = collapseShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpCollapseShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,

@dcaballe
Copy link
Contributor Author

Interesting enough, I can't find where memref::populateExpandStridedMetadataPatterns is being tested???

Copy link

github-actions bot commented Apr 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@qcolombet
Copy link
Collaborator

Interesting enough, I can't find where memref::populateExpandStridedMetadataPatterns is being tested???

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir, no?

@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}

template <typename ReassociativeReshapeLikeOp,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move the comment explaining what this does?

@qcolombet
Copy link
Collaborator

Interesting enough, I can't find where memref::populateExpandStridedMetadataPatterns is being tested???

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir, no?

I guess you meant populateResolveExtractStridedMetadataPatterns since that's the one you're modifying.
I thought that populateResolveExtractStridedMetadataPatterns was a subset of populateExpandStridedMetadataPatterns but looks like it changed with https://reviews.llvm.org/D147393 and I didn't notice.

I see two paths forward

  1. use populateResolveExtractStridedMetadataPatterns inside populateExpandStridedMetadataPatterns so that we can be sure Resolve is a subset of Expand. I think this is true today (modulo what https://reviews.llvm.org/D147393 added, but re-converging remains good IMHO) and we can revisit when they need to diverge.
  2. integrate a test in the narrowing type test pass, since this is the only place where Resolve is used.

Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo:

  1. add a test
  2. move some comments on the refactored helper function

Comment on lines 555 to 561
SmallVector<OpFoldResult> (*getReshapedSizes)(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
SmallVector<OpFoldResult> (*getReshapedStrides)(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/,
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fly-by nit: is there an observable benefit to using these template arguments as opposed to passing in function_ref callbacks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not. I mostly followed the style of the caller

@dcaballe
Copy link
Contributor Author

I see two paths forward

  1. use populateResolveExtractStridedMetadataPatterns inside populateExpandStridedMetadataPatterns so that we can be sure Resolve is a subset of Expand. I think this is true today (modulo what reviews.llvm.org/D147393 added, but re-converging remains good IMHO) and we can revisit when they need to diverge.
  2. integrate a test in the narrowing type test pass, since this is the only place where Resolve is used.

Thanks for the suggesting. Ok, let me try 1.

@dcaballe dcaballe merged commit 450ac01 into llvm:main Apr 26, 2024
3 of 4 checks passed
@dcaballe dcaballe deleted the expand-metadata-collapse branch April 26, 2024 14:20
dcaballe added a commit to iree-org/llvm-project that referenced this pull request Apr 26, 2024
…89954)

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and
stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new
pattern resolves strided_metadata(collapse_shape) directly, without introduce a
reshape_cast op.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants