Skip to content

[mlir][vector] Allow multi dim vectors in vector.scatter #132217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2025

Conversation

Groverkss
Copy link
Member

@Groverkss Groverkss commented Mar 20, 2025

This patch matches the definition of vector.scatter as a counter part of vector.gather.

All of the changes done in this patch make vector.scatter match vector.gather 's multi dimensional definition.

Unrolling for vector.scatter will be implemented in subsequent patches.

Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13

@Groverkss
Copy link
Member Author

Depends on #132206

@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2025

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch matches the definition of vector.scatter as a counter part of vector.gather.

All of the changes done in this patch make vector.scatter match vector.gather 's multi dimensional definition.

Unrolling for vector.scatter will be implemented in subsequent patches.


Patch is 21.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132217.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+6-6)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+6-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+18-32)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+12-6)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-46)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+37-3)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+10-8)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from a n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in a n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 /// [FlattenGather]
 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
 ///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     auto loc = gather->getLoc();
 
     // Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
     // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Resolve address.
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+///   strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
-               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..f5c722e29420c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 //===----------------------------------------------------------------------===//
 // vector.gather
-//
+// 
 // NOTE: vector.constant_mask won't lower with
 //  * --convert-to-llvm="filter-dialects=vector",
 // hence testing here.
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
   // vector.constant_mask only supports 'none set' or 'all set' scalable
   // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
   // width vectors above.
-  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
   return %2 : vector<2x[3]xf32>
 }
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.interleave
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
 func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @gather_on_tensor
 func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf3...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

This patch matches the definition of vector.scatter as a counter part of vector.gather.

All of the changes done in this patch make vector.scatter match vector.gather 's multi dimensional definition.

Unrolling for vector.scatter will be implemented in subsequent patches.


Patch is 21.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132217.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+6-6)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+6-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+18-32)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+12-6)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-46)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+37-3)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+10-8)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from a n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in a n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 /// [FlattenGather]
 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
 ///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     auto loc = gather->getLoc();
 
     // Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
     // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Resolve address.
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+///   strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
-               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..f5c722e29420c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 //===----------------------------------------------------------------------===//
 // vector.gather
-//
+// 
 // NOTE: vector.constant_mask won't lower with
 //  * --convert-to-llvm="filter-dialects=vector",
 // hence testing here.
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
   // vector.constant_mask only supports 'none set' or 'all set' scalable
   // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
   // width vectors above.
-  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
   return %2 : vector<2x[3]xf32>
 }
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.interleave
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
 func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @gather_on_tensor
 func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf3...
[truncated]

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. LGTM after #132206 lands

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aligning vector.gather + vector.scatter makes a lot of sense, thanks for taking care of that.

I've left some minor comments. Also, there should be new tests in https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/ops.mlir

Thanks!

@Groverkss
Copy link
Member Author

Groverkss commented Mar 21, 2025

Hi @banach-space , this pr as mentioned above depends on #132206 , which is why you may be seeing unrelated changes here. Github doesnt have the best way of having stacked prs, so it's hard to indicate this.

Sorry for that. Please only review the topmost commit. The earlier commits are from depending prs. I will add a comment to other prs as well.

@banach-space
Copy link
Contributor

Hi @banach-space , this pr as mentioned above depends on #132206 , which is why you may be seeing unrelated changes here. Github doesnt have the best way of having stacked prs, so it's hard to indicate this.

Please either document such things in the summary, or use https://llvm.org/docs/GitHub.html#using-graphite-for-stacked-pull-requests (I've not used it myself).

Also, if you are sending multiple dependent PRs to achieve a larger goal, could you create a GitHub issue to summarise and to communicate that? Otherwise it's hard to track what's going on (and to see the bigger picture). Specifically, I assume that these are related:

? If yes, what's the order and the dependencies? Could this be documented outside this PR?

@makslevental
Copy link
Contributor

makslevental commented Mar 21, 2025

Also, if you are sending multiple dependent PRs to achieve a larger goal, could you create a GitHub issue to summarise and to communicate that? Otherwise it's hard to track what's going on (and to see the bigger picture)

This is not as far as I am aware current policy? Going back several pages in the issues I see no such tracking issues. In addition, given that there's 5k+ issues currently, I think such issues (such a policy) are utterly useless.

@banach-space
Copy link
Contributor

banach-space commented Mar 21, 2025

This is not as far as I am aware current policy?

That wasn’t a reference to a formal policy — just a kind request.

GitHub is admittedly not ideal for stacked PRs, but there are some lightweight practices that make a big difference. One common approach is to mention dependencies explicitly in the PR summary. For example:

given that there's 5k+ issues currently, I think such issues

I’m not sure I follow the connection. The issue volume is definitely a challenge, but that's all the more reason to find ways to improve clarity and coordination. Open to other suggestions too!

@makslevental
Copy link
Contributor

makslevental commented Mar 21, 2025

I’m not sure I follow the connection.

If the role of tracking issues is to centralize information about a workstream then they need to be in a central place so that they are easier to find than the threads of the workstream itself.

Since we have 5k+ issues, in order to find a tracking issue, one will have to either:

  1. scroll 5k+ issues;
  2. remember the name of the tracking issue;
  3. go to one of the threads of the workstream (PRs) that links to the tracking issue in order to jump to the tracking issue.

That last bullet should lucidly illustrate the utter uselessness of the proposition (since it entails annotating each PR of the workstream with a link to the central tracking issue in order to enable finding the tracking issue that tracks the PRs one just used to find the tracking issue).

One common approach is to mention dependencies explicitly in the PR summary.

There is a comment literally right below the summary that says

image

Open to other suggestions too!

My suggestion is that we focus on doing the job we get paid for (i.e., landing PRs that meaningfully contribute to LLVM) instead of further bureaucratizing the process and thereby actively discouraging/impeding meaningful contribution.

@Groverkss
Copy link
Member Author

Groverkss commented Mar 23, 2025

Hi @banach-space , this pr as mentioned above depends on #132206 , which is why you may be seeing unrelated changes here. Github doesnt have the best way of having stacked prs, so it's hard to indicate this.

Please either document such things in the summary, or use https://llvm.org/docs/GitHub.html#using-graphite-for-stacked-pull-requests (I've not used it myself).

Also, if you are sending multiple dependent PRs to achieve a larger goal, could you create a GitHub issue to summarise and to communicate that? Otherwise it's hard to track what's going on (and to see the bigger picture). Specifically, I assume that these are related:

? If yes, what's the order and the dependencies? Could this be documented outside this PR?

I added a link to the discourse discussion. All PRs are listed there in order of dependency and how reviewers are expected to review them. I don't think a github issue has enough context than the discourse thread, which is why i prefered the discourse comment.

Let's get back to technical discussion now. I'm happy to create a github issue or use graphite or anything that the reviewers prefer if it will make it easier for them to review it. For now, the discourse discussion link makes it clear. Please let me know if anything else is needed.

Thanks for the review so far!

@banach-space
Copy link
Contributor

Hi Kunwar,

Thanks for updating the PR with the Discourse link — that definitely helps clarify the context and dependencies. And no worries about the format; I appreciate your flexibility on whether to use a GitHub issue, Graphite, or Discourse.

Also, apologies — I missed the note below the summary referencing the other PR. That was my oversight; I scrolled past it too quickly.

As you can see from the PRs I linked, contributors often mention dependencies directly in the PR summary. That’s where I was expecting to find it too, but clearly there’s no consistent practice yet. To help with standardization and reduce this kind of confusion in the future, I’ve submitted:

Hopefully once that lands, it’ll make things clearer for both authors and reviewers.

One additional note, directed more generally to the thread: some earlier phrasing (e.g., referring to a suggestion as “utterly useless”) felt unnecessarily dismissive. While I’m sure no harm was intended, I want to remind everyone that LLVM’s Code of Conduct encourages us to keep discussions professional and respectful, even when we disagree. That’s essential for keeping the project collaborative and welcoming for all contributors.

Thanks again — looking forward to reviewing the next PRs in the stack!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits, thanks!

@Groverkss Groverkss merged commit 24a8e18 into llvm:main Mar 24, 2025
8 of 10 checks passed
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants