-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][vector] Add more tests for ConvertVectorToLLVM (3/n) #102854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add more tests for ConvertVectorToLLVM (3/n) #102854
Conversation
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesAdds tests with scalable vectors for the Vector-To-LLVM conversion pass.
I have also renamed some function names from Full diff: https://github.com/llvm/llvm-project/pull/102854.diff 1 Files Affected:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d164e875097968..9b61c4493994c2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1049,8 +1049,8 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3
// -----
-// CHECK-LABEL: @extract_element_0d
-func.func @extract_element_0d(%a: vector<f32>) -> f32 {
+// CHECK-LABEL: @extractelement_0d
+func.func @extractelement_0d(%a: vector<f32>) -> f32 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
%1 = vector.extractelement %a[] : vector<f32>
@@ -1059,31 +1059,54 @@ func.func @extract_element_0d(%a: vector<f32>) -> f32 {
// -----
-func.func @extract_element(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement(%arg0: vector<16xf32>) -> f32 {
%0 = arith.constant 15 : i32
%1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
return %1 : f32
}
-// CHECK-LABEL: @extract_element(
+// CHECK-LABEL: @extractelement(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[c:.*]] = arith.constant 15 : i32
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<16xf32>
// CHECK: return %[[x]] : f32
+func.func @extractelement_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = arith.constant 15 : i32
+ %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32>
+ return %1 : f32
+}
+// CHECK-LABEL: @extractelement_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+// CHECK: %[[c:.*]] = arith.constant 15 : i32
+// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<[16]xf32>
+// CHECK: return %[[x]] : f32
+
// -----
-func.func @extract_element_index(%arg0: vector<16xf32>) -> f32 {
+func.func @extractelement_index(%arg0: vector<16xf32>) -> f32 {
%0 = arith.constant 15 : index
%1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
return %1 : f32
}
-// CHECK-LABEL: @extract_element_index(
+// CHECK-LABEL: @extractelement_index(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
// CHECK: %[[c:.*]] = arith.constant 15 : index
// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32>
// CHECK: return %[[x]] : f32
+func.func @extractelement_index_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = arith.constant 15 : index
+ %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32>
+ return %1 : f32
+}
+// CHECK-LABEL: @extractelement_index_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+// CHECK: %[[c:.*]] = arith.constant 15 : index
+// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
+// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<[16]xf32>
+// CHECK: return %[[x]] : f32
+
// -----
func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
@@ -1095,6 +1118,15 @@ func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
// CHECK: return {{.*}} : f32
+func.func @extract_element_from_vec_1d_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_1d_scalable
+// CHECK: llvm.mlir.constant(15 : i64) : i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+// CHECK: return {{.*}} : f32
+
// -----
func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
@@ -1109,6 +1141,18 @@ func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
// CHECK: return %[[T3]] : index
+func.func @extract_index_element_from_vec_1d_scalable(%arg0: vector<[16]xindex>) -> index {
+ %0 = vector.extract %arg0[15]: index from vector<[16]xindex>
+ return %0 : index
+}
+// CHECK-LABEL: @extract_index_element_from_vec_1d_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xindex>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[16]xindex> to vector<[16]xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<[16]xi64>
+// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+// CHECK: return %[[T3]] : index
+
// -----
func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
@@ -1119,6 +1163,14 @@ func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16x
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<16xf32>>>
// CHECK: return {{.*}} : vector<3x16xf32>
+func.func @extract_vec_2d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
+ %0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
+ return %0 : vector<3x[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_2d_from_vec_3d_scalable
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+// CHECK: return {{.*}} : vector<3x[16]xf32>
+
// -----
func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
@@ -1129,6 +1181,14 @@ func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf3
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<16xf32>>>
// CHECK: return {{.*}} : vector<16xf32>
+func.func @extract_vec_1d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<[16]xf32> {
+ %0 = vector.extract %arg0[0, 0]: vector<[16]xf32> from vector<4x3x[16]xf32>
+ return %0 : vector<[16]xf32>
+}
+// CHECK-LABEL: @extract_vec_1d_from_vec_3d_scalable
+// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+// CHECK: return {{.*}} : vector<[16]xf32>
+
// -----
func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
@@ -1141,6 +1201,16 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
// CHECK: return {{.*}} : f32
+func.func @extract_element_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> f32 {
+ %0 = vector.extract %arg0[0, 0, 0]: f32 from vector<4x3x[16]xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: @extract_element_from_vec_3d_scalable
+// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
+// CHECK: llvm.mlir.constant(0 : i64) : i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
+// CHECK: return {{.*}} : f32
+
// -----
func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
@@ -1152,6 +1222,15 @@ func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) ->
// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
+func.func @extract_element_with_value_1d_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: @extract_element_with_value_1d_scalable
+// CHECK-SAME: %[[VEC:.+]]: vector<[16]xf32>, %[[INDEX:.+]]: index
+// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<[16]xf32>
+
// -----
func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Optional: Rebase to ensure Linux pre-commit passes (current failure seems unrelated)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.extractelement * vector.extract I have also renamed some function names from `@extract_element{}` to `@extractelement{}` - that's to make a clearer distinction between tests for `vector.extractelement` (tested by `@extractelement{}`) and `vector.extract` (tested by `@extract_element{}`).
6588586
to
c909c00
Compare
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
I have also renamed some function names from
@extract_element{}
to@extractelement{}
- that's to make a clearer distinction betweentests for
vector.extractelement
(tested by@extractelement{}
) andvector.extract
(tested by@extract_element{}
).