Skip to content

Commit e9b3338

Browse files
committed
[mlir][vector] Allow multi dim vectors in vector.scatter
1 parent bba9af2 commit e9b3338

File tree

6 files changed

+58
-19
lines changed

6 files changed

+58
-19
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,19 +2034,19 @@ def Vector_ScatterOp :
20342034
Vector_Op<"scatter">,
20352035
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
20362036
Variadic<Index>:$indices,
2037-
VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
2038-
VectorOfRankAndType<[1], [I1]>:$mask,
2039-
VectorOfRank<[1]>:$valueToStore)> {
2037+
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
2038+
VectorOfNonZeroRankOf<[I1]>:$mask,
2039+
AnyVectorOfNonZeroRank:$valueToStore)> {
20402040

20412041
let summary = [{
20422042
scatters elements from a vector into memory as defined by an index vector
20432043
and a mask vector
20442044
}];
20452045

20462046
let description = [{
2047-
The scatter operation stores elements from a 1-D vector into memory as
2048-
defined by a base with indices and an additional 1-D index vector, but
2049-
only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
2047+
The scatter operation stores elements from a n-D vector into memory as
2048+
defined by a base with indices and an additional n-D index vector, but
2049+
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
20502050
action is taken for that element. Informally the semantics are:
20512051
```
20522052
if (mask[0]) base[index[0]] = value[0]

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,16 @@ class VectorScatterOpConversion
313313
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
314314
return failure();
315315

316+
VectorType vType = scatter.getVectorType();
317+
if (vType.getRank() > 1)
318+
return failure();
319+
316320
// Resolve alignment.
317321
unsigned align;
318322
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
319323
return failure();
320324

321325
// Resolve address.
322-
VectorType vType = scatter.getVectorType();
323326
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
324327
adaptor.getIndices(), rewriter);
325328
Value ptrs =

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
53405340
return emitOpError("base and valueToStore element type should match");
53415341
if (llvm::size(getIndices()) != memType.getRank())
53425342
return emitOpError("requires ") << memType.getRank() << " indices";
5343-
if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5343+
if (valueVType.getShape() != indVType.getShape())
53445344
return emitOpError("expected valueToStore dim to match indices dim");
5345-
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5345+
if (valueVType.getShape() != maskVType.getShape())
53465346
return emitOpError("expected valueToStore dim to match mask dim");
53475347
return success();
53485348
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
16541654

16551655
//===----------------------------------------------------------------------===//
16561656
// vector.gather
1657-
//
1657+
//
16581658
// NOTE: vector.constant_mask won't lower with
16591659
// * --convert-to-llvm="filter-dialects=vector",
16601660
// hence testing here.
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
17191719

17201720
// -----
17211721

1722+
//===----------------------------------------------------------------------===//
1723+
// vector.scatter
1724+
//===----------------------------------------------------------------------===//
1725+
1726+
// Multi-Dimensional scatters are not supported yet. Check that we do not lower
1727+
// them.
1728+
1729+
func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
1730+
%0 = arith.constant 0: index
1731+
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
1732+
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
1733+
return
1734+
}
1735+
1736+
// CHECK-LABEL: func @scatter_with_mask
1737+
// CHECK: vector.scatter
1738+
1739+
// -----
1740+
1741+
func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
1742+
%0 = arith.constant 0: index
1743+
// vector.constant_mask only supports 'none set' or 'all set' scalable
1744+
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
1745+
// width vectors above.
1746+
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
1747+
vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
1748+
return
1749+
}
1750+
1751+
// CHECK-LABEL: func @scatter_with_mask_scalable
1752+
// CHECK: vector.scatter
1753+
1754+
// -----
1755+
17221756
//===----------------------------------------------------------------------===//
17231757
// vector.interleave
17241758
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
14841484
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
14851485
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
14861486
%c0 = arith.constant 0 : index
1487-
// expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
1487+
// expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
14881488
vector.scatter %base[%c0][%indices], %mask, %value
14891489
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
14901490
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
882882
return
883883
}
884884

885+
// CHECK-LABEL: @gather_multi_dims
886+
func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
887+
%c0 = arith.constant 0 : index
888+
// CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
889+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
890+
// CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
891+
vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
892+
return %0 : vector<2x16xf32>
893+
}
894+
885895
// CHECK-LABEL: @gather_on_tensor
886896
func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
887897
%c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
890900
return %0 : vector<16xf32>
891901
}
892902

893-
// CHECK-LABEL: @gather_multi_dims
894-
func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
895-
%c0 = arith.constant 0 : index
896-
// CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
897-
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
898-
return %0 : vector<2x16xf32>
899-
}
900-
901903
// CHECK-LABEL: @expand_and_compress
902904
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
903905
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)