diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 34a94e6ea7051..3dc7d38440ca5 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -137,12 +137,14 @@ void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, void populateVectorTransferFullPartialPatterns( RewritePatternSet &patterns, const VectorTransformsOptions &options); -/// Collect a set of patterns to reduce the rank of the operands of vector -/// transfer ops to operate on the largest contigious vector. -/// These patterns are useful when lowering to dialects with 1d vector type -/// such as llvm and it will result fewer memory reads. -void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Collect a set of patterns to collapse the most inner unit dims in xfer Ops +/// +/// These patters reduce the rank of the operands of vector transfer ops to +/// operate on vectors without trailing unit dims. This helps reduce the rank of +/// the operands, which can be helpful when lowering to dialects that only +/// support 1D vector type such as LLVM. +void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Patterns that remove redundant Vector Ops by re-ordering them with /// e.g. elementwise Ops: diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 36fc55f3f311d..bcaea1c79471f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2266,11 +2266,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns( void mlir::vector::populateDropUnitDimWithShapeCastPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - // TODO: Consider either: - // * including DropInnerMostUnitDimsTransferRead and - // DropInnerMostUnitDimsTransferWrite, or - // * better naming to distinguish this and - // populateVectorTransferCollapseInnerMostContiguousDimsPatterns. patterns.add(patterns.getContext(), benefit); } @@ -2305,9 +2300,8 @@ void mlir::vector::populateVectorReductionToContractPatterns( patterns.getContext(), benefit); } -void mlir::vector:: - populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { +void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 54aa96ba89a00..026dda46ecdc4 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -368,7 +368,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); + populateDropInnerMostUnitDimsXferOpPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } };