Skip to content

[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… #68291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};

/// Replace `base, offset, sizes, strides =
/// extract_strided_metadata(
/// cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
/// ? dstTy.srcOffset
/// : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
/// !srcSize.isDynamic()
/// ? srcSize
// : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
/// !srcStrides.isDynamic()
/// ? srcStrides
/// : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
PatternRewriter &rewriter) const override {
Value source = extractStridedMetadataOp.getSource();
auto castOp = source.getDefiningOp<memref::CastOp>();
if (!castOp)
return failure();

Location loc = extractStridedMetadataOp.getLoc();
// Check if the source is suitable for extract_strided_metadata.
SmallVector<Type> inferredReturnTypes;
if (failed(extractStridedMetadataOp.inferReturnTypes(
rewriter.getContext(), loc, {castOp.getSource()},
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
inferredReturnTypes)))
return rewriter.notifyMatchFailure(castOp,
"cast source's type is incompatible");

auto memrefType = cast<MemRefType>(source.getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc,
castOp.getSource());

// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();

auto getConstantOrValue = [&rewriter](int64_t constant,
OpFoldResult ofr) -> OpFoldResult {
return !ShapedType::isDynamic(constant)
? OpFoldResult(rewriter.getIndexAttr(constant))
: ofr;
};

auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
assert(sourceStrides.size() == rank && "unexpected number of strides");

// Register the new offset.
results[1] =
getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

const unsigned sizeStartIdx = 2;
const unsigned strideStartIdx = sizeStartIdx + rank;
ArrayRef<int64_t> sourceSizes = memrefType.getShape();

SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
for (unsigned i = 0; i < rank; ++i) {
results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
results[strideStartIdx + i] =
getConstantOrValue(sourceStrides[i], strides[i]);
}
rewriter.replaceOp(extractStridedMetadataOp,
getValueOrCreateConstantIndexOp(rewriter, loc, results));
return success();
}
};

/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
Expand Down Expand Up @@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand All @@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand Down
125 changes: 125 additions & 0 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<3x?xi32, strided<[4, ?], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<?x?xi32, strided<[?, ?], offset: ?>> to
memref<4x?xi32, strided<[?, 18], offset: 25>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}
// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
%arg : memref<*xi32>)
-> (memref<i32>, index,
index, index,
index, index) {

%cast =
memref.cast %arg :
memref<*xi32> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}