Skip to content

[mlir][vector] Add more tests for ConvertVectorToLLVM (10/n) #117041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.maskedload,
  • vector.maskedstore,
  • vector.gather,
  • vector.scatter.

In addition:

  • For consistency with other tests, renamed test function names
    (e.g. @masked_load_op -> @masked_load_op)
  • Made some test names more descriptive, e.g @gather_op_2d ->
    @gather_1d_from_2d.

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  * `vector.maskedload`,
  * `vector.maskedstore`,
  * `vector.gather`,
  * `vector.scatter`.

In addition:
* For consistency with other tests, renamed test function names
  (e.g. `@masked_load_op` -> `@masked_load_op`)
* Made some test names more descriptive, e.g `@gather_op_2d` ->
  `@gather_1d_from_2d`.
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.maskedload,
  • vector.maskedstore,
  • vector.gather,
  • vector.scatter.

In addition:

  • For consistency with other tests, renamed test function names
    (e.g. @<!-- -->masked_load_op -> @<!-- -->masked_load_op)
  • Made some test names more descriptive, e.g @<!-- -->gather_op_2d ->
    @<!-- -->gather_1d_from_2d.

Patch is 24.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117041.diff

1 Files Affected:

  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+202-34)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 529dd4094507fa..da0222bc942376 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3042,13 +3042,13 @@ func.func @vector_store_index_scalable(%memref : memref<200x100xindex>, %i : ind
 
 // -----
 
-func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
+func.func @vector_store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = arith.constant dense<11.0> : vector<f32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
   return
 }
 
-// CHECK-LABEL: func @vector_store_op_0d
+// CHECK-LABEL: func @vector_store_0d
 // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
 // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -3057,13 +3057,13 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 
 // -----
 
-func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
 
-// CHECK-LABEL: func @masked_load_op
+// CHECK-LABEL: func @masked_load
 // CHECK: %[[CO:.*]] = arith.constant 0 : index
 // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -3072,23 +3072,48 @@ func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vec
 
 // -----
 
-func.func @masked_load_op_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
+func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
+  %c0 = arith.constant 0: index
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+  return %0 : vector<[16]xf32>
+}
+
+// CHECK-LABEL: func @masked_load_scalable
+// CHECK: %[[CO:.*]] = arith.constant 0 : index
+// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<[16]xi1>, vector<[16]xf32>) -> vector<[16]xf32>
+// CHECK: return %[[L]] : vector<[16]xf32>
+
+// -----
+
+func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
   return %0 : vector<16xindex>
 }
-// CHECK-LABEL: func @masked_load_op_index
+// CHECK-LABEL: func @masked_load_index
 // CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xi64>) -> vector<16xi64>
 
 // -----
 
-func.func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
+  %c0 = arith.constant 0: index
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>
+  return %0 : vector<[16]xindex>
+}
+// CHECK-LABEL: func @masked_load_index_scalable
+// CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<[16]xi1>, vector<[16]xi64>) -> vector<[16]xi64>
+
+// -----
+
+func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
   %c0 = arith.constant 0: index
   vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
-// CHECK-LABEL: func @masked_store_op
+// CHECK-LABEL: func @masked_store
 // CHECK: %[[CO:.*]] = arith.constant 0 : index
 // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -3096,76 +3121,126 @@ func.func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: ve
 
 // -----
 
-func.func @masked_store_op_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
+func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
+  %c0 = arith.constant 0: index
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+  return
+}
+
+// CHECK-LABEL: func @masked_store_scalable
+// CHECK: %[[CO:.*]] = arith.constant 0 : index
+// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[16]xf32>, vector<[16]xi1> into !llvm.ptr
+
+// -----
+
+func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
   %c0 = arith.constant 0: index
   vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
   return
 }
-// CHECK-LABEL: func @masked_store_op_index
+// CHECK-LABEL: func @masked_store_index
 // CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xi64>, vector<16xi1> into !llvm.ptr
 
 // -----
 
-func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
+  %c0 = arith.constant 0: index
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex>
+  return
+}
+// CHECK-LABEL: func @masked_store_index_scalable
+// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<[16]xi64>, vector<[16]xi1> into !llvm.ptr
+
+// -----
+
+func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
   return %1 : vector<3xf32>
 }
 
-// CHECK-LABEL: func @gather_op
+// CHECK-LABEL: func @gather
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
 
-func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
+func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
   return %1 : vector<[3]xf32>
 }
 
-// CHECK-LABEL: func @gather_op_scalable
+// CHECK-LABEL: func @gather_scalable
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: return %[[G]] : vector<[3]xf32>
 
 // -----
 
-func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
   return %1 : vector<3xf32>
 }
 
-// CHECK-LABEL: func @gather_op
+// CHECK-LABEL: func @gather_global_memory
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
 
+func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+  return %1 : vector<[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_global_memory_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: return %[[G]] : vector<[3]xf32>
+
+// -----
+
 
-func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
+func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>
   return %1 : vector<3xindex>
 }
 
-// CHECK-LABEL: func @gather_op_index
+// CHECK-LABEL: func @gather_index
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
 
 // -----
 
-func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xindex>) -> vector<[3]xindex> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xindex> into vector<[3]xindex>
+  return %1 : vector<[3]xindex>
+}
+
+// CHECK-LABEL: func @gather_index_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
+// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
+
+// -----
+
+func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %1 : vector<2x3xf32>
 }
 
-// CHECK-LABEL: func @gather_op_multi_dims
+// CHECK-LABEL: func @gather_2d_from_1d
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
 // CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
@@ -3182,40 +3257,94 @@ func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %a
 
 // -----
 
-func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
+  return %1 : vector<2x[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_2d_from_1d_scalable
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
+// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
+// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
+// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
+// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
+
+// -----
+
+func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
 
-// CHECK-LABEL: func @gather_op_with_mask
+// CHECK-LABEL: func @gather_with_mask
 // CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 
 // -----
 
-func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
+  return %2 : vector<2x[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_with_mask_scalable
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+
+
+// -----
+
+func.func @gather_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
 
-// CHECK-LABEL: func @gather_op_with_zero_mask
+// CHECK-LABEL: func @gather_with_zero_mask
 // CHECK-SAME:    (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
 // CHECK-NOT:   %{{.*}} = llvm.intr.masked.gather
 // CHECK:       return %[[S]] : vector<2x3xf32>
 
 // -----
 
-func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
+func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [0, 0] : vector<2x[3]xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
+  return %2 : vector<2x[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_with_zero_mask_scalable
+// CHECK-SAME:    (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x[3]xi32>, %[[S:.*]]: vector<2x[3]xf32>)
+// CHECK-NOT:   %{{.*}} = llvm.intr.masked.gather
+// CHECK:       return %[[S]] : vector<2x[3]xf32>
+
+// -----
+
+func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %1 : vector<4xf32>
 }
 
-// CHECK-LABEL: func @gather_2d_op
+// CHECK-LABEL: func @gather_1d_from_2d
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
@@ -3223,55 +3352,94 @@ func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vec
 
 // -----
 
-func.func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]xi32>, %arg2: vector<[4]xi1>, %arg3: vector<[4]xf32>) -> vector<[4]xf32> {
+  %0 = arith.constant 3 : index
+  %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x?xf32>, vector<[4]xi32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+  return %1 : vector<[4]xf32>
+}
+
+// CHECK-LABEL: func @gather_1d_from_2d_scalable
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: return %[[G]] : vector<[4]xf32>
+
+// -----
+
+func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
   %0 = arith.constant 0: index
   vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
   return
 }
 
-// CHECK-LABEL: func @scatter_op
+// CHECK-LABEL: func @scatter
 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
 // CHECK: llvm.intr.maske...
[truncated]

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 NIT but I realised there were other occurences in the file. I won't mind to let it be.
LGTM.

@@ -3042,13 +3042,13 @@ func.func @vector_store_index_scalable(%memref : memref<200x100xindex>, %i : ind

// -----

func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
func.func @vector_store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT or part of another MR, this test has args like %memref whereas other seem to follow argX type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out!

There's actually more inconsistencies in this file, e.g.

  • @vector_store_0d vs @scatter_index (i.e. @vector_{op_name}_{variant} vs @{op_name}_{variant}.

I was planning to send a dedicated refactor patch at the end :)

@banach-space banach-space merged commit 4086ead into llvm:main Nov 21, 2024
10 checks passed
@banach-space banach-space deleted the andrzej/extend_vector_to_llvm_test_11 branch November 21, 2024 09:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants