Skip to content

Commit 24a8e18

Browse files
authored
[mlir][vector] Allow multi dim vectors in vector.scatter (#132217)
This patch matches the definition of vector.scatter as a counter part of vector.gather. All of the changes done in this patch make vector.scatter match vector.gather 's multi dimensional definition. Unrolling for vector.scatter will be implemented in subsequent patches. Discourse Discussion: https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
1 parent e60fe2e commit 24a8e18

File tree

6 files changed

+74
-28
lines changed

6 files changed

+74
-28
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,19 +2034,19 @@ def Vector_ScatterOp :
20342034
Vector_Op<"scatter">,
20352035
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
20362036
Variadic<Index>:$indices,
2037-
VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
2038-
VectorOfRankAndType<[1], [I1]>:$mask,
2039-
VectorOfRank<[1]>:$valueToStore)> {
2037+
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
2038+
VectorOfNonZeroRankOf<[I1]>:$mask,
2039+
AnyVectorOfNonZeroRank:$valueToStore)> {
20402040

20412041
let summary = [{
20422042
scatters elements from a vector into memory as defined by an index vector
20432043
and a mask vector
20442044
}];
20452045

20462046
let description = [{
2047-
The scatter operation stores elements from a 1-D vector into memory as
2048-
defined by a base with indices and an additional 1-D index vector, but
2049-
only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
2047+
The scatter operation stores elements from a n-D vector into memory as
2048+
defined by a base with indices and an additional n-D index vector, but
2049+
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
20502050
action is taken for that element. Informally the semantics are:
20512051
```
20522052
if (mask[0]) base[index[0]] = value[0]

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -263,22 +263,25 @@ class VectorGatherOpConversion
263263
LogicalResult
264264
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
265265
ConversionPatternRewriter &rewriter) const override {
266+
Location loc = gather->getLoc();
266267
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
267268
assert(memRefType && "The base should be bufferized");
268269

269270
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
270-
return failure();
271+
return rewriter.notifyMatchFailure(gather, "memref type not supported");
271272

272273
VectorType vType = gather.getVectorType();
273-
if (vType.getRank() > 1)
274-
return failure();
275-
276-
Location loc = gather->getLoc();
274+
if (vType.getRank() > 1) {
275+
return rewriter.notifyMatchFailure(
276+
gather, "only 1-D vectors can be lowered to LLVM");
277+
}
277278

278279
// Resolve alignment.
279280
unsigned align;
280-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
281-
return failure();
281+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
282+
return rewriter.notifyMatchFailure(gather,
283+
"could not resolve memref alignment");
284+
}
282285

283286
// Resolve address.
284287
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -309,15 +312,22 @@ class VectorScatterOpConversion
309312
MemRefType memRefType = scatter.getMemRefType();
310313

311314
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
312-
return failure();
315+
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
316+
317+
VectorType vType = scatter.getVectorType();
318+
if (vType.getRank() > 1) {
319+
return rewriter.notifyMatchFailure(
320+
scatter, "only 1-D vectors can be lowered to LLVM");
321+
}
313322

314323
// Resolve alignment.
315324
unsigned align;
316-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
317-
return failure();
325+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
326+
return rewriter.notifyMatchFailure(scatter,
327+
"could not resolve memref alignment");
328+
}
318329

319330
// Resolve address.
320-
VectorType vType = scatter.getVectorType();
321331
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
322332
adaptor.getIndices(), rewriter);
323333
Value ptrs =

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
53405340
return emitOpError("base and valueToStore element type should match");
53415341
if (llvm::size(getIndices()) != memType.getRank())
53425342
return emitOpError("requires ") << memType.getRank() << " indices";
5343-
if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5343+
if (valueVType.getShape() != indVType.getShape())
53445344
return emitOpError("expected valueToStore dim to match indices dim");
5345-
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5345+
if (valueVType.getShape() != maskVType.getShape())
53465346
return emitOpError("expected valueToStore dim to match mask dim");
53475347
return success();
53485348
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
17191719

17201720
// -----
17211721

1722+
//===----------------------------------------------------------------------===//
1723+
// vector.scatter
1724+
//===----------------------------------------------------------------------===//
1725+
1726+
// Multi-Dimensional scatters are not supported yet. Check that we do not lower
1727+
// them.
1728+
1729+
func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
1730+
%0 = arith.constant 0: index
1731+
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
1732+
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
1733+
return
1734+
}
1735+
1736+
// CHECK-LABEL: func @scatter_with_mask
1737+
// CHECK: vector.scatter
1738+
1739+
// -----
1740+
1741+
func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
1742+
%0 = arith.constant 0: index
1743+
// vector.constant_mask only supports 'none set' or 'all set' scalable
1744+
// dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
1745+
// width vectors above.
1746+
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
1747+
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
1748+
return
1749+
}
1750+
1751+
// CHECK-LABEL: func @scatter_with_mask_scalable
1752+
// CHECK: vector.scatter
1753+
1754+
// -----
1755+
17221756
//===----------------------------------------------------------------------===//
17231757
// vector.interleave
17241758
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
14841484
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
14851485
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
14861486
%c0 = arith.constant 0 : index
1487-
// expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
1487+
// expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
14881488
vector.scatter %base[%c0][%indices], %mask, %value
14891489
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
14901490
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
882882
return
883883
}
884884

885+
// CHECK-LABEL: @gather_and_scatter_multi_dims
886+
func.func @gather_and_scatter_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
887+
%c0 = arith.constant 0 : index
888+
// CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
889+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
890+
// CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
891+
vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
892+
return %0 : vector<2x16xf32>
893+
}
894+
885895
// CHECK-LABEL: @gather_on_tensor
886896
func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
887897
%c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
890900
return %0 : vector<16xf32>
891901
}
892902

893-
// CHECK-LABEL: @gather_multi_dims
894-
func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
895-
%c0 = arith.constant 0 : index
896-
// CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
897-
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
898-
return %0 : vector<2x16xf32>
899-
}
900-
901903
// CHECK-LABEL: @expand_and_compress
902904
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
903905
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)