
[mlir][vector] Decouple unrolling gather and gather to llvm lowering #132206


Merged: 2 commits merged into llvm:main on Mar 24, 2025

Conversation

Groverkss
Member

@Groverkss Groverkss commented Mar 20, 2025

This patch decouples unrolling of vector.gather from lowering vector.gather to llvm.masked.gather.

This is consistent with how vector.load, vector.store, vector.maskedload, and vector.maskedstore are lowered to LLVM.

Some interesting test changes from this patch:

  • Tests for lowering 2-D vector.gather directly to LLVM are deleted. This is consistent with the other memory load/store ops.
  • Tests for 2-D vector.gather remain, but their constant mask is modified. This is because, with the updated lowering, one of the unrolled vector.gather ops disappears when it is fully masked off (also demonstrating why this is a better lowering path).

Overall, this makes vector.gather take the same consistent lowering path to LLVM as the other load/store ops.

Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
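For illustration, here is a minimal sketch of the two now-separate steps on a 2-D gather (the 2x3 shapes and value names are assumptions for this example, not taken verbatim from the patch): the vector-dialect unrolling patterns peel the outer dimension into 1-D gathers, and only those 1-D gathers are later converted to llvm.intr.masked.gather.

```mlir
// Input: a 2-D gather (shapes assumed for illustration).
%g = vector.gather %base[%c0][%idxs], %mask, %pass
       : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>

// After populateVectorGatherLoweringPatterns (still in the vector dialect), roughly:
%i0 = vector.extract %idxs[0] : vector<3xi32> from vector<2x3xi32>
%m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
%p0 = vector.extract %pass[0] : vector<3xf32> from vector<2x3xf32>
%g0 = vector.gather %base[%c0][%i0], %m0, %p0
        : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
%r0 = vector.insert %g0, %pass[0] : vector<3xf32> into vector<2x3xf32>
%i1 = vector.extract %idxs[1] : vector<3xi32> from vector<2x3xi32>
%m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
%p1 = vector.extract %pass[1] : vector<3xf32> from vector<2x3xf32>
%g1 = vector.gather %base[%c0][%i1], %m1, %p1
        : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
%r  = vector.insert %g1, %r0[1] : vector<3xf32> into vector<2x3xf32>
// The 1-D %g0 / %g1 are then the only forms VectorGatherOpConversion
// lowers to llvm.intr.masked.gather.
```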

@llvmbot
Member

llvmbot commented Mar 20, 2025

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch decouples unrolling of vector.gather from lowering vector.gather to llvm.masked.gather.

This is consistent with how vector.load, vector.store, vector.maskedload, and vector.maskedstore are lowered to LLVM.

Some interesting test changes from this patch:

  • Tests for lowering 2-D vector.gather directly to LLVM are deleted. This is consistent with the other memory load/store ops.
  • Tests for 2-D vector.gather remain, but their constant mask is modified. This is because, with the updated lowering, one of the unrolled vector.gather ops disappears when it is fully masked off (also demonstrating why this is a better lowering path).

Overall, this makes vector.gather take the same consistent lowering path to LLVM as the other load/store ops.


Full diff: https://github.com/llvm/llvm-project/pull/132206.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+6-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-31)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+12-6)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-46)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+2-2)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 /// [FlattenGather]
 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
 ///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..148c2ae26be35 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     auto loc = gather->getLoc();
 
     // Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
     // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Resolve address.
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+///   strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
-               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..4a36294b355e7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
   // vector.constant_mask only supports 'none set' or 'all set' scalable
   // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
   // width vectors above.
-  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
   return %2 : vector<2x[3]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 74838bc0ca2fb..2cf1dde9bd1b8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateVectorGatherLoweringPatterns(patterns);
+    populateVectorGatherToConditionalLoadPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
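To make the net effect of the ConvertVectorToLLVM.cpp change concrete, here is a minimal sketch (with assumed 1-D shapes, not lifted from the tests) of the only case VectorGatherOpConversion still handles after this patch; rank-2+ gathers are expected to have been unrolled by populateVectorGatherLoweringPatterns before conversion:

```mlir
func.func @gather_1d_sketch(%base: memref<?xf32>, %idxs: vector<4xi32>,
                            %mask: vector<4xi1>, %pass: vector<4xf32>) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  // Rank-1 gather: handled directly by VectorGatherOpConversion.
  %g = vector.gather %base[%c0][%idxs], %mask, %pass
         : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %g : vector<4xf32>
}
// Expected lowering, roughly (mirroring the remaining 1-D tests):
//   %ptrs = llvm.getelementptr %p[%idxs] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
//   %g    = llvm.intr.masked.gather %ptrs, %mask, %pass {alignment = 4 : i32}
//             : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
```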

@llvmbot
Member

llvmbot commented Mar 20, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)


Contributor

@qedawkins qedawkins left a comment


LG, this makes sense to me. Give it some time before landing so others can take a look, in case they are relying on the multi-dimensional gather lowering.

@Groverkss
Member Author

Landing this to make progress on subsequent patches. This is a fairly non-controversial change that effectively does the same thing, just earlier, to make things more consistent. Happy to address further comments after landing, in accordance with https://llvm.org/docs/CodeReview.html#can-code-be-reviewed-after-it-is-committed

@Groverkss Groverkss merged commit cf0efb3 into llvm:main Mar 24, 2025
8 of 11 checks passed
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {

 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
Contributor


Why is this test modified? I am asking because previously one of the outer lanes was masked out and now it isn't. Is this significant?

Member Author

@Groverkss Groverkss Mar 24, 2025


Hi, this is mentioned in the PR summary:

There are still tests for 2D vector.gather, but the constant mask for these test is modified. This is because with the updated lowering, one of the unrolled vector.gather disappears because it is masked off (also demonstrating why this is a better lowering path)

It's to make sure that unrolling is actually producing 2 unrolled gathers.
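As a small sketch of what the mask change does (row values written out by hand from the constant_mask semantics, not copied from the tests): with [1, 2] the second row is all false, so after unrolling the second 1-D gather folds away to its pass-through, whereas [2, 2] keeps both rows partially set and the test exercises two unrolled gathers.

```mlir
// [1, 2] : vector<2x3xi1>  ==  row 0 = [1, 1, 0], row 1 = [0, 0, 0]
%old = vector.constant_mask [1, 2] : vector<2x3xi1>
// [2, 2] : vector<2x3xi1>  ==  row 0 = [1, 1, 0], row 1 = [1, 1, 0]
%new = vector.constant_mask [2, 2] : vector<2x3xi1>
```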

Contributor


Ah ok, now I see what you meant, thanks!

Contributor

@dcaballe dcaballe left a comment


Seeing this now, cool! :)

krzysz00 added a commit to iree-org/iree that referenced this pull request Apr 1, 2025
- Add calls to `populateVectorGatherLoweringPatterns` and
`populateVectorGatherToConditionalLoadPatterns` as needed to handle
upstream llvm/llvm-project#132206
- Update ROCDL tests to account for
llvm/llvm-project#131850
- Disable `tosa.reduce_prod` and `tosa.clamp` tests on integers
temporarily since they're failing Tosa verification - see #20422

Explicitly set `amdhsa_code_object_version` when constructing ROCm LLVM
modules, explicitly pass `-mcode-object-version=N` to microkernel
compiles, and explicitly set the `__oclc_ABI_version` global, so that we
don't get hit with errors about being unable to run our binaries from
the compiler default code object version getting bumped further than
what our users or CI support.

The default code object version in LLVM as of this commit is 6, we pin
to 5. The followup work to update to v6 is #20423 .
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Apr 8, 2025