Skip to content

[mlir][vector] Add more tests for ConvertVectorToLLVM (3/n) #102854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.extractelement
  • vector.extract

I have also renamed some function names from @extract_element{} to
@extractelement{} - that's to make a clearer distinction between
tests for vector.extractelement (tested by @extractelement{}) and
vector.extract (tested by @extract_element{}).

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.extractelement
  • vector.extract

I have also renamed some function names from @<!-- -->extract_element{} to
@<!-- -->extractelement{} - that's to make a clearer distinction between
tests for vector.extractelement (tested by @<!-- -->extractelement{}) and
vector.extract (tested by @<!-- -->extract_element{}).


Full diff: https://github.com/llvm/llvm-project/pull/102854.diff

1 Files Affected:

  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+85-6)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d164e875097968..9b61c4493994c2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1049,8 +1049,8 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3
 
 // -----
 
-// CHECK-LABEL: @extract_element_0d
-func.func @extract_element_0d(%a: vector<f32>) -> f32 {
+// CHECK-LABEL: @extractelement_0d
+func.func @extractelement_0d(%a: vector<f32>) -> f32 {
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
   // CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
   %1 = vector.extractelement %a[] : vector<f32>
@@ -1059,31 +1059,54 @@ func.func @extract_element_0d(%a: vector<f32>) -> f32 {
 
 // -----
 
-func.func @extract_element(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement(%arg0: vector<16xf32>) -> f32 {
   %0 = arith.constant 15 : i32
   %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
   return %1 : f32
 }
-// CHECK-LABEL: @extract_element(
+// CHECK-LABEL: @extractelement(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //       CHECK:   %[[c:.*]] = arith.constant 15 : i32
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<16xf32>
 //       CHECK:   return %[[x]] : f32
 
+func.func @extractelement_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = arith.constant 15 : i32
+  %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32>
+  return %1 : f32
+}
+// CHECK-LABEL: @extractelement_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 15 : i32
+//       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<[16]xf32>
+//       CHECK:   return %[[x]] : f32
+
 // -----
 
-func.func @extract_element_index(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement_index(%arg0: vector<16xf32>) -> f32 {
   %0 = arith.constant 15 : index
   %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
   return %1 : f32
 }
-// CHECK-LABEL: @extract_element_index(
+// CHECK-LABEL: @extractelement_index(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 //       CHECK:   %[[c:.*]] = arith.constant 15 : index
 //       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32>
 //       CHECK:   return %[[x]] : f32
 
+func.func @extractelement_index_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = arith.constant 15 : index
+  %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32>
+  return %1 : f32
+}
+// CHECK-LABEL: @extractelement_index_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 15 : index
+//       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
+//       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<[16]xf32>
+//       CHECK:   return %[[x]] : f32
+
 // -----
 
 func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
@@ -1095,6 +1118,15 @@ func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
 //       CHECK:   return {{.*}} : f32
 
+func.func @extract_element_from_vec_1d_scalable(%arg0: vector<[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_1d_scalable
+//       CHECK:   llvm.mlir.constant(15 : i64) : i64
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+//       CHECK:   return {{.*}} : f32
+
 // -----
 
 func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
@@ -1109,6 +1141,18 @@ func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
 //       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
 //       CHECK:   return %[[T3]] : index
 
+func.func @extract_index_element_from_vec_1d_scalable(%arg0: vector<[16]xindex>) -> index {
+  %0 = vector.extract %arg0[15]: index from vector<[16]xindex>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_index_element_from_vec_1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xindex>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[16]xindex> to vector<[16]xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<[16]xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
@@ -1119,6 +1163,14 @@ func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16x
 //       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<16xf32>>>
 //       CHECK:   return {{.*}} : vector<3x16xf32>
 
+func.func @extract_vec_2d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
+  %0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
+  return %0 : vector<3x[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_2d_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   return {{.*}} : vector<3x[16]xf32>
+
 // -----
 
 func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
@@ -1129,6 +1181,14 @@ func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf3
 //       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<16xf32>>>
 //       CHECK:   return {{.*}} : vector<16xf32>
 
+func.func @extract_vec_1d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<[16]xf32> {
+  %0 = vector.extract %arg0[0, 0]: vector<[16]xf32> from vector<4x3x[16]xf32>
+  return %0 : vector<[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_1d_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   return {{.*}} : vector<[16]xf32>
+
 // -----
 
 func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
@@ -1141,6 +1201,16 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
 //       CHECK:   return {{.*}} : f32
 
+func.func @extract_element_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> f32 {
+  %0 = vector.extract %arg0[0, 0, 0]: f32 from vector<4x3x[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_3d_scalable
+//       CHECK:   llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+//       CHECK:   llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+//       CHECK:   return {{.*}} : f32
+
 // -----
 
 func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
@@ -1152,6 +1222,15 @@ func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) ->
 //       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
 //       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
 
+func.func @extract_element_with_value_1d_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_with_value_1d_scalable
+//  CHECK-SAME:   %[[VEC:.+]]: vector<[16]xf32>, %[[INDEX:.+]]: index
+//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+//       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<[16]xf32>
+
 // -----
 
 func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
Optional: Rebase to ensure Linux pre-commit passes (current failure seems unrelated)

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extractelement
  * vector.extract

I have also renamed some function names from `@extract_element{}` to
`@extractelement{}` - that's to make a clearer distinction between
tests for `vector.extractelement` (tested by `@extractelement{}`) and
`vector.extract` (tested by `@extract_element{}`).
@banach-space banach-space force-pushed the andrzej/extend_vector_to_llvm_test_3 branch from 6588586 to c909c00 Compare August 13, 2024 10:14
@banach-space banach-space merged commit f0f5afe into llvm:main Aug 13, 2024
8 checks passed
@banach-space banach-space deleted the andrzej/extend_vector_to_llvm_test_3 branch August 13, 2024 17:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants