[mlir][vector] Allow multi dim vectors in vector.scatter #132217
Depends on #132206

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes: This patch matches the definition of vector.scatter as a counterpart of vector.gather. All of the changes in this patch make vector.scatter match vector.gather's multi-dimensional definition. Unrolling for vector.scatter will be implemented in subsequent patches.

Patch is 21.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132217.diff

11 Files Affected:
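For illustration, here is a sketch of IR that the relaxed op definitions accept, mirroring the gather/scatter round-trip test added to ops.mlir in this patch: all three vector operands of vector.scatter may now be n-D, as long as their shapes match.

```mlir
func.func @gather_scatter_2d(%base: memref<?xf32>, %v: vector<2x16xi32>,
                             %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) {
  %c0 = arith.constant 0 : index
  // 2-D gather and scatter over a 1-D memref; the index, mask, and value
  // vectors all share the shape 2x16.
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
      : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
  vector.scatter %base[%c0][%v], %mask, %0
      : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
  return
}
```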
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$valueToStore)> {
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
}];
let description = [{
- The scatter operation stores elements from a 1-D vector into memory as
- defined by a base with indices and an additional 1-D index vector, but
- only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+ The scatter operation stores elements from a n-D vector into memory as
+ defined by a base with indices and an additional n-D index vector, but
+ only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
```
if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
/// [FlattenGather]
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = gather.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
auto loc = gather->getLoc();
// Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value base = adaptor.getBase();
- auto llvmNDVectorTy = adaptor.getIndexVec().getType();
// Handle the simple case of 1-D vector.
- if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
- auto vType = gather.getVectorType();
- // Resolve address.
- Value ptrs =
- getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
- }
-
- const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
- auto callback = [align, memRefType, base, ptr, loc, &rewriter,
- &typeConverter](Type llvm1DVectorTy,
- ValueRange vectorOperands) {
- // Resolve address.
- Value ptrs = getIndexedPtrs(
- rewriter, loc, typeConverter, memRefType, base, ptr,
- /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
- // Create the gather intrinsic.
- return rewriter.create<LLVM::masked_gather>(
- loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
- /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
- };
- SmallVector<Value> vectorOperands = {
- adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
- return LLVM::detail::handleMultidimensionalVectors(
- gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+ // Resolve address.
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ base, ptr, adaptor.getIndexVec(), vType);
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
}
};
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = scatter.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
- VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
+ populateVectorGatherLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
- if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+ if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
- if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+ if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already flat");
+ return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, RemoveStrideFromGatherSource,
- Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+ patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
// -----
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
- return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..f5c722e29420c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
//===----------------------------------------------------------------------===//
// vector.gather
-//
+//
// NOTE: vector.constant_mask won't lower with
// * --convert-to-llvm="filter-dialects=vector",
// hence testing here.
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = arith.constant 0: index
- %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
return %2 : vector<2x3xf32>
}
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
// vector.constant_mask only supports 'none set' or 'all set' scalable
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
// width vectors above.
- %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
return %2 : vector<2x[3]xf32>
}
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// -----
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+ %0 = arith.constant 0: index
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+ %0 = arith.constant 0: index
+ // vector.constant_mask only supports 'none set' or 'all set' scalable
+ // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+ // width vectors above.
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.interleave
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
return
}
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+ vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+ return %0 : vector<2x16xf32>
+}
+
// CHECK-LABEL: @gather_on_tensor
func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf3...
[truncated]
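As the LoweringPatterns.h hunk above shows, the single gather-lowering entry point is split in two. Below is a minimal sketch, under the assumption that a downstream pipeline wants to keep the previous end-to-end lowering; the helper name is hypothetical, while the two populate functions come from the patch itself.

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper: recovers the pre-split behavior by registering both
// pattern sets that populateVectorGatherLoweringPatterns used to add alone.
static void addFullGatherLowering(mlir::RewritePatternSet &patterns) {
  // UnrollGather: unrolls n-D vector.gather ops into 1-D gathers.
  mlir::vector::populateVectorGatherLoweringPatterns(patterns);
  // RemoveStrideFromGatherSource + Gather1DToConditionalLoads: scalarizes
  // the remaining 1-D gathers into conditional loads/extracts.
  mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
}
```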
Makes sense to me. LGTM after #132206 lands
Aligning `vector.gather` + `vector.scatter` makes a lot of sense, thanks for taking care of that. I've left some minor comments. Also, there should be new tests in https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/ops.mlir. Thanks!
Hi @banach-space, this PR, as mentioned above, depends on #132206, which is why you may be seeing unrelated changes here. GitHub doesn't have the best way of handling stacked PRs, so it's hard to indicate this. Sorry for that. Please only review the topmost commit; the earlier commits are from the PRs this one depends on. I will add a comment to the other PRs as well.
Please either document such things in the summary, or use https://llvm.org/docs/GitHub.html#using-graphite-for-stacked-pull-requests (I've not used it myself). Also, if you are sending multiple dependent PRs to achieve a larger goal, could you create a GitHub issue to summarise and to communicate that? Otherwise it's hard to track what's going on (and to see the bigger picture). Specifically, I assume that these are related:
? If yes, what's the order and the dependencies? Could this be documented outside this PR?
This is not, as far as I am aware, current policy? Going back several pages in the issues, I see no such tracking issues. In addition, given that there's
That wasn't a reference to a formal policy, just a kind request. GitHub is admittedly not ideal for stacked PRs, but there are some lightweight practices that make a big difference. One common approach is to mention dependencies explicitly in the PR summary. For example:
I'm not sure I follow the connection. The issue volume is definitely a challenge, but that's all the more reason to find ways to improve clarity and coordination. Open to other suggestions too!
I added a link to the Discourse discussion. All PRs are listed there in order of dependency, along with how reviewers are expected to review them. I don't think a GitHub issue would have more context than the Discourse thread, which is why I preferred the Discourse comment. Let's get back to the technical discussion now. I'm happy to create a GitHub issue, use Graphite, or do anything else the reviewers prefer if it will make the review easier. For now, the Discourse discussion link makes it clear. Please let me know if anything else is needed. Thanks for the review so far!
Hi Kunwar, thanks for updating the PR with the Discourse link; that definitely helps clarify the context and dependencies. And no worries about the format; I appreciate your flexibility on whether to use a GitHub issue, Graphite, or Discourse. Also, apologies: I missed the note below the summary referencing the other PR. That was my oversight; I scrolled past it too quickly. As you can see from the PRs I linked, contributors often mention dependencies directly in the PR summary. That's where I was expecting to find it too, but clearly there's no consistent practice yet. To help with standardization and reduce this kind of confusion in the future, I've submitted:
Hopefully once that lands, it'll make things clearer for both authors and reviewers. One additional note, directed more generally at the thread: some earlier phrasing (e.g., referring to a suggestion as "utterly useless") felt unnecessarily dismissive. While I'm sure no harm was intended, I want to remind everyone that LLVM's Code of Conduct encourages us to keep discussions professional and respectful, even when we disagree. That's essential for keeping the project collaborative and welcoming for all contributors. Thanks again; looking forward to reviewing the next PRs in the stack!
LGTM % nits, thanks!
Force-pushed e9b3338 to 2d7fa91
Force-pushed 2d7fa91 to 81d517a
LG, thanks!
This patch matches the definition of vector.scatter as a counterpart of vector.gather.
All of the changes in this patch make vector.scatter match vector.gather's multi-dimensional definition.
Unrolling for vector.scatter will be implemented in subsequent patches.
Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
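For reference, here is a shape-mismatch example the updated verifier rejects, adapted from the invalid.mlir change above (the trailing return is added here for completeness); instead of the old rank-restriction error, it now reports "expected valueToStore dim to match indices dim".

```mlir
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %value: vector<2x16xf32>) {
  %c0 = arith.constant 0 : index
  // Rejected: valueToStore has shape 2x16 but the index vector has shape 16.
  vector.scatter %base[%c0][%indices], %mask, %value
      : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
  return
}
```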