Skip to content

[mlir][vector][nfc] Rename populateVectorTransferCollapseInnerMostContiguousDimsPatterns #145228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Jun 22, 2025

Renames populateVectorTransferCollapseInnerMostContiguousDimsPatterns
as populateDropInnerMostUnitDimsXferOpPatterns + updates the
corresponding comments.

This addresses a TODO and makes the difference between these two
populate* methods clearer:

  • populateDropUnitDimWithShapeCastPatterns,
  • populateDropInnerMostUnitDimsXferOpPatterns.

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg][nfc] Split hoisting tests into dedicated test functions
  • Move test 2
  • Move test 3
  • Move test 4
  • Move test 5
  • Remove test 6 that duplicats @negative_xfer_pair_double_loop_mem_use_inside_inner_loop_before_write
  • Add block comments, simplify names
  • [mlir][vector][nfc] Rename populateVectorTransferCollapseInnerMostContiguousDimsPatterns

Full diff: https://github.com/llvm/llvm-project/pull/145228.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+8-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+2-8)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+204-63)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..3dc7d38440ca5 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -137,12 +137,14 @@ void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
 void populateVectorTransferFullPartialPatterns(
     RewritePatternSet &patterns, const VectorTransformsOptions &options);
 
-/// Collect a set of patterns to reduce the rank of the operands of vector
-/// transfer ops to operate on the largest contigious vector.
-/// These patterns are useful when lowering to dialects with 1d vector type
-/// such as llvm and it will result fewer memory reads.
-void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Collect a set of patterns to collapse the most inner unit dims in xfer Ops
+///
+/// These patters reduce the rank of the operands of vector transfer ops to
+/// operate on vectors without trailing unit dims. This helps reduce the rank of
+/// the operands, which can be helpful when lowering to dialects that only
+/// support 1D vector type such as LLVM.
+void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
 
 /// Patterns that remove redundant Vector Ops by re-ordering them with
 /// e.g. elementwise Ops:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 36fc55f3f311d..bcaea1c79471f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2266,11 +2266,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  // TODO: Consider either:
-  //  * including DropInnerMostUnitDimsTransferRead and
-  //    DropInnerMostUnitDimsTransferWrite, or
-  //  * better naming to distinguish this and
-  //    populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
                DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
 }
@@ -2305,9 +2300,8 @@ void mlir::vector::populateVectorReductionToContractPatterns(
       patterns.getContext(), benefit);
 }
 
-void mlir::vector::
-    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-        RewritePatternSet &patterns, PatternBenefit benefit) {
+void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DropInnerMostUnitDimsTransferRead,
                DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                    benefit);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..461ca79340703 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,76 +1,210 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
-// CHECK-LABEL: func @hoist_vector_transfer_pairs(
-//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
-func.func @hoist_vector_transfer_pairs(
-    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
-    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Nested in double loops
+//  * Indices depend on induction variables
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @mem_use_outside
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:               %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               scf.yield %[[USE]] : vector<1xf32>
+// CHECK:             }
+// CHECK:             vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
+    }
+  }
+  "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @mem_use_inside_outer_loop
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:               %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               scf.yield %[[USE]] : vector<1xf32>
+// CHECK:             }
+// CHECK:             vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
+    }
+    "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Nested in double loops
+//  * Indices are constant
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:               %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:               %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:               vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             }
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+      vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
 
-// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
-// CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
-// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
-// CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
-// CHECK:     "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
-// CHECK:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:               %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:               %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:               "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:             }
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+    }
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:     "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
+// CHECK:     "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
 // CHECK:   }
-// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
-// CHECK:   "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
-// CHECK:   scf.yield {{.*}} : vector<1xf32>
 // CHECK: }
-// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
-// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
   scf.for %i = %lb to %ub step %step {
     scf.for %j = %lb to %ub step %step {
-      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
-      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
-      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
-      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
-      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
-      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
-      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
-      "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
-      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
-      %u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
-      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
-      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
-      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
-      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
-      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
-      vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
-      vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
-      vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
-      vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
-      "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+      %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
     }
-    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
   }
-  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
   return
 }
 
@@ -86,6 +220,13 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// Other tests
+///
+/// TODO: Document
+///----------------------------------------------------------------------------------------
+
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..026dda46ecdc4 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -368,7 +368,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
+    populateDropInnerMostUnitDimsXferOpPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg][nfc] Split hoisting tests into dedicated test functions
  • Move test 2
  • Move test 3
  • Move test 4
  • Move test 5
  • Remove test 6 that duplicats @negative_xfer_pair_double_loop_mem_use_inside_inner_loop_before_write
  • Add block comments, simplify names
  • [mlir][vector][nfc] Rename populateVectorTransferCollapseInnerMostContiguousDimsPatterns

Full diff: https://github.com/llvm/llvm-project/pull/145228.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+8-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+2-8)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+204-63)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..3dc7d38440ca5 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -137,12 +137,14 @@ void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
 void populateVectorTransferFullPartialPatterns(
     RewritePatternSet &patterns, const VectorTransformsOptions &options);
 
-/// Collect a set of patterns to reduce the rank of the operands of vector
-/// transfer ops to operate on the largest contigious vector.
-/// These patterns are useful when lowering to dialects with 1d vector type
-/// such as llvm and it will result fewer memory reads.
-void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Collect a set of patterns to collapse the most inner unit dims in xfer Ops
+///
+/// These patters reduce the rank of the operands of vector transfer ops to
+/// operate on vectors without trailing unit dims. This helps reduce the rank of
+/// the operands, which can be helpful when lowering to dialects that only
+/// support 1D vector type such as LLVM.
+void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
 
 /// Patterns that remove redundant Vector Ops by re-ordering them with
 /// e.g. elementwise Ops:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 36fc55f3f311d..bcaea1c79471f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2266,11 +2266,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  // TODO: Consider either:
-  //  * including DropInnerMostUnitDimsTransferRead and
-  //    DropInnerMostUnitDimsTransferWrite, or
-  //  * better naming to distinguish this and
-  //    populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
                DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
 }
@@ -2305,9 +2300,8 @@ void mlir::vector::populateVectorReductionToContractPatterns(
       patterns.getContext(), benefit);
 }
 
-void mlir::vector::
-    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-        RewritePatternSet &patterns, PatternBenefit benefit) {
+void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DropInnerMostUnitDimsTransferRead,
                DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                    benefit);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..461ca79340703 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,76 +1,210 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
-// CHECK-LABEL: func @hoist_vector_transfer_pairs(
-//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
-//  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
-//  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
-func.func @hoist_vector_transfer_pairs(
-    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
-    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
-    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Nested in double loops
+//  * Indices depend on induction variables
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @mem_use_outside
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:               %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               scf.yield %[[USE]] : vector<1xf32>
+// CHECK:             }
+// CHECK:             vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
+    }
+  }
+  "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @mem_use_inside_outer_loop
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:               %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               scf.yield %[[USE]] : vector<1xf32>
+// CHECK:             }
+// CHECK:             vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
+    }
+    "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for vector.transfer_read + vector.transfer_write pairs
+///
+/// * Nested in double loops
+//  * Indices are constant
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:               %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:               %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:               vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             }
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+      vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
 
-// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
-// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
-// CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
-// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
-// CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
-// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK:     "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
-// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
-// CHECK:     "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
-// CHECK:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:               %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:               %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:               vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:               "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:             }
+// CHECK:           }
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+    }
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index)
+func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:     "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
+// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
+// CHECK:     "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:     vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
 // CHECK:   }
-// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
-// CHECK:   "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
-// CHECK:   scf.yield {{.*}} : vector<1xf32>
 // CHECK: }
-// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
-// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
   scf.for %i = %lb to %ub step %step {
     scf.for %j = %lb to %ub step %step {
-      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
-      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
-      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
-      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
-      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
-      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
-      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
-      "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
-      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
-      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
-      %u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
-      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
-      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
-      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
-      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
-      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
-      vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
-      vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
-      vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
-      vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
-      "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
+      "mem_use"(%mem) : (memref<?x?xf32>) -> ()
+      %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
     }
-    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
   }
-  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
   return
 }
 
@@ -86,6 +220,13 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// Other tests
+///
+/// TODO: Document
+///----------------------------------------------------------------------------------------
+
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..026dda46ecdc4 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -368,7 +368,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
+    populateDropInnerMostUnitDimsXferOpPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };

@joker-eph
Copy link
Collaborator

Can you fix the title/description please?

…ntiguousDimsPatterns`

Renames `populateVectorTransferCollapseInnerMostContiguousDimsPatterns`
as `populateDropInnerMostUnitDimsXferOpPatterns` + updates the
corresponding comments.

This addresses a TODO and makes the difference between these two
`populate*` methods clearer:
 * `populateDropUnitDimWithShapeCastPatterns`,
 * `populateDropInnerMostUnitDimsXferOpPatterns`,
@banach-space banach-space force-pushed the andrzej/vector/rename_patterns branch from c77283d to adfde79 Compare June 22, 2025 12:02
@banach-space banach-space changed the title andrzej/vector/rename patterns [mlir][vector][nfc] Rename populateVectorTransferCollapseInnerMostContiguousDimsPatterns Jun 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants