[mlir][vector] LoadOp/StoreOp: Allow 0-D vectors #76134

Merged: 1 commit, Dec 22, 2023
42 changes: 27 additions & 15 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1582,22 +1582,27 @@ def Vector_LoadOp : Vector_Op<"load"> {
vector. If the memref element type is vector, it should match the result
vector type.

- Example 1: 1-D vector load on a scalar memref.
+ Example: 0-D vector load on a scalar memref.
+ ```mlir
+ %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<f32>
+ ```
+
+ Example: 1-D vector load on a scalar memref.
```mlir
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
```

- Example 2: 1-D vector load on a vector memref.
+ Example: 1-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```

- Example 3: 2-D vector load on a scalar memref.
+ Example: 2-D vector load on a scalar memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```

- Example 4: 2-D vector load on a vector memref.
+ Example: 2-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```
@@ -1608,12 +1613,12 @@ def Vector_LoadOp : Vector_Op<"load"> {
loaded out of bounds. Not all targets may support out-of-bounds vector
loads.

- Example 5: Potential out-of-bound vector load.
+ Example: Potential out-of-bound vector load.
```mlir
%result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
```

- Example 6: Explicit out-of-bound vector load.
+ Example: Explicit out-of-bound vector load.
```mlir
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
```
@@ -1622,7 +1627,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices);
- let results = (outs AnyVector:$result);
+ let results = (outs AnyVectorOfAnyRank:$result);

let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1660,22 +1665,27 @@ def Vector_StoreOp : Vector_Op<"store"> {
to store. If the memref element type is vector, it should match the type
of the value to store.

- Example 1: 1-D vector store on a scalar memref.
+ Example: 0-D vector store on a scalar memref.
+ ```mlir
+ vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+ ```
+
+ Example: 1-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
```

- Example 2: 1-D vector store on a vector memref.
+ Example: 1-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```

- Example 3: 2-D vector store on a scalar memref.
+ Example: 2-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```

- Example 4: 2-D vector store on a vector memref.
+ Example: 2-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```
@@ -1685,21 +1695,23 @@ def Vector_StoreOp : Vector_Op<"store"> {
target-specific. No assumptions should be made on the memory written out of
bounds. Not all targets may support out-of-bounds vector stores.

- Example 5: Potential out-of-bounds vector store.
+ Example: Potential out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
```

- Example 6: Explicit out-of-bounds vector store.
+ Example: Explicit out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
```
}];

- let arguments = (ins AnyVector:$valueToStore,
+ let arguments = (ins
+     AnyVectorOfAnyRank:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
- Variadic<Index>:$indices);
+     Variadic<Index>:$indices
+ );

let extraClassDeclaration = [{
MemRefType getMemRefType() {
30 changes: 30 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2059,6 +2059,36 @@ func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j

// -----

func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<f32> {
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
return %0 : vector<f32>
}

// CHECK-LABEL: func @vector_load_op_0d
// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
// CHECK: return %[[cast]] : vector<f32>

// -----

func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
%val = arith.constant dense<11.0> : vector<f32>
vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
return
}

// CHECK-LABEL: func @vector_store_op_0d
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]

// -----

func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0: index
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
@@ -714,6 +714,16 @@ func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
return %0 : vector<16xi32>
}

// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>

Collaborator:
Unrelated to this change, but this reminded me that the types are the wrong way round for vector.store.

Contributor:
I was told that that's intentional (so that vector.load and vector.store look similar). Also not a fan, but that's tangential to this PR.

Collaborator:
> I was told that that's intentional (so that vector.load and vector.store look similar).

I see; it's a bit inconsistent w.r.t. transfer_read / transfer_write, where the type order matches the inputs.

> Also not a fan, but that's tangential to this PR.

Of course.

return
}
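
To make the type-order point from the thread above concrete, here is a small illustrative snippet (hypothetical, not part of this PR; function and value names are placeholders): vector.store spells the memref type before the stored vector type, while vector.transfer_write lists its types in the same order as its operands.

```mlir
// Illustrative only: contrasting the trailing type lists of the two stores.
func.func @type_order_contrast(%mem : memref<?xf32>, %val : vector<8xf32>) {
  %c0 = arith.constant 0 : index
  // vector.store: memref type first, vector type second, even though the
  // value to store is the first operand.
  vector.store %val, %mem[%c0] : memref<?xf32>, vector<8xf32>
  // vector.transfer_write: vector type first, then memref type, matching the
  // operand order (%val, %mem).
  vector.transfer_write %val, %mem[%c0] : vector<8xf32>, memref<?xf32>
  return
}
```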

// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {