Skip to content

Commit def3d23

Browse files
committed
[mlir][vector] Add more tests for ConvertVectorToLLVM (4/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.insertelement * vector.insert I have also renamed some function names from `@insert_element{}` to `@insertelement{}` - that's to make a clearer distinction between tests for `vector.insertelement` (tested by `@insertelement{}`) and `vector.insert` (tested by `@insert_element{}`).
1 parent 2c8bd4a commit def3d23

File tree

1 file changed

+91
-4
lines changed

1 file changed

+91
-4
lines changed

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,9 +1245,9 @@ func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index)
12451245

12461246
// -----
12471247

1248-
// CHECK-LABEL: @insert_element_0d
1248+
// CHECK-LABEL: @insertelement_0d
12491249
// CHECK-SAME: %[[A:.*]]: f32,
1250-
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
1250+
func.func @insertelement_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
12511251
// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
12521252
// CHECK: vector<f32> to vector<1xf32>
12531253
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -1258,18 +1258,30 @@ func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
12581258

12591259
// -----
12601260

1261-
func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
1261+
func.func @insertelement(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
12621262
%0 = arith.constant 3 : i32
12631263
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
12641264
return %1 : vector<4xf32>
12651265
}
1266-
// CHECK-LABEL: @insert_element(
1266+
// CHECK-LABEL: @insertelement(
12671267
// CHECK-SAME: %[[A:.*]]: f32,
12681268
// CHECK-SAME: %[[B:.*]]: vector<4xf32>)
12691269
// CHECK: %[[c:.*]] = arith.constant 3 : i32
12701270
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : i32] : vector<4xf32>
12711271
// CHECK: return %[[x]] : vector<4xf32>
12721272

1273+
func.func @insertelement_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
1274+
%0 = arith.constant 3 : i32
1275+
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32>
1276+
return %1 : vector<[4]xf32>
1277+
}
1278+
// CHECK-LABEL: @insertelement_scalable(
1279+
// CHECK-SAME: %[[A:.*]]: f32,
1280+
// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>)
1281+
// CHECK: %[[c:.*]] = arith.constant 3 : i32
1282+
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : i32] : vector<[4]xf32>
1283+
// CHECK: return %[[x]] : vector<[4]xf32>
1284+
12731285
// -----
12741286

12751287
func.func @insert_element_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
@@ -1285,6 +1297,19 @@ func.func @insert_element_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf3
12851297
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[i]] : i64] : vector<4xf32>
12861298
// CHECK: return %[[x]] : vector<4xf32>
12871299

1300+
func.func @insertelement_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
1301+
%0 = arith.constant 3 : index
1302+
%1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32>
1303+
return %1 : vector<[4]xf32>
1304+
}
1305+
// CHECK-LABEL: @insertelement_index_scalable(
1306+
// CHECK-SAME: %[[A:.*]]: f32,
1307+
// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>)
1308+
// CHECK: %[[c:.*]] = arith.constant 3 : index
1309+
// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
1310+
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[i]] : i64] : vector<[4]xf32>
1311+
// CHECK: return %[[x]] : vector<[4]xf32>
1312+
12881313
// -----
12891314

12901315
func.func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
@@ -1296,6 +1321,15 @@ func.func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vecto
12961321
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
12971322
// CHECK: return {{.*}} : vector<4xf32>
12981323

1324+
func.func @insert_element_into_vec_1d_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
1325+
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<[4]xf32>
1326+
return %0 : vector<[4]xf32>
1327+
}
1328+
// CHECK-LABEL: @insert_element_into_vec_1d_scalable
1329+
// CHECK: llvm.mlir.constant(3 : i64) : i64
1330+
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<[4]xf32>
1331+
// CHECK: return {{.*}} : vector<[4]xf32>
1332+
12991333
// -----
13001334

13011335
func.func @insert_index_element_into_vec_1d(%arg0: index, %arg1: vector<4xindex>) -> vector<4xindex> {
@@ -1312,6 +1346,21 @@ func.func @insert_index_element_into_vec_1d(%arg0: index, %arg1: vector<4xindex>
13121346
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : vector<4xi64> to vector<4xindex>
13131347
// CHECK: return %[[T5]] : vector<4xindex>
13141348

1349+
1350+
func.func @insert_index_element_into_vec_1d_scalable(%arg0: index, %arg1: vector<[4]xindex>) -> vector<[4]xindex> {
1351+
%0 = vector.insert %arg0, %arg1[3] : index into vector<[4]xindex>
1352+
return %0 : vector<[4]xindex>
1353+
}
1354+
// CHECK-LABEL: @insert_index_element_into_vec_1d_scalable(
1355+
// CHECK-SAME: %[[A:.*]]: index,
1356+
// CHECK-SAME: %[[B:.*]]: vector<[4]xindex>)
1357+
// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
1358+
// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<[4]xindex> to vector<[4]xi64>
1359+
// CHECK: %[[T3:.*]] = llvm.mlir.constant(3 : i64) : i64
1360+
// CHECK: %[[T4:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T3]] : i64] : vector<[4]xi64>
1361+
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : vector<[4]xi64> to vector<[4]xindex>
1362+
// CHECK: return %[[T5]] : vector<[4]xindex>
1363+
13151364
// -----
13161365

13171366
func.func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
@@ -1322,6 +1371,14 @@ func.func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x
13221371
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm.array<4 x array<8 x vector<16xf32>>>
13231372
// CHECK: return {{.*}} : vector<4x8x16xf32>
13241373

1374+
func.func @insert_vec_2d_into_vec_3d_scalable(%arg0: vector<8x[16]xf32>, %arg1: vector<4x8x[16]xf32>) -> vector<4x8x[16]xf32> {
1375+
%0 = vector.insert %arg0, %arg1[3] : vector<8x[16]xf32> into vector<4x8x[16]xf32>
1376+
return %0 : vector<4x8x[16]xf32>
1377+
}
1378+
// CHECK-LABEL: @insert_vec_2d_into_vec_3d_scalable
1379+
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm.array<4 x array<8 x vector<[16]xf32>>>
1380+
// CHECK: return {{.*}} : vector<4x8x[16]xf32>
1381+
13251382
// -----
13261383

13271384
func.func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
@@ -1332,6 +1389,14 @@ func.func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16
13321389
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm.array<4 x array<8 x vector<16xf32>>>
13331390
// CHECK: return {{.*}} : vector<4x8x16xf32>
13341391

1392+
func.func @insert_vec_1d_into_vec_3d_scalable(%arg0: vector<[16]xf32>, %arg1: vector<4x8x[16]xf32>) -> vector<4x8x[16]xf32> {
1393+
%0 = vector.insert %arg0, %arg1[3, 7] : vector<[16]xf32> into vector<4x8x[16]xf32>
1394+
return %0 : vector<4x8x[16]xf32>
1395+
}
1396+
// CHECK-LABEL: @insert_vec_1d_into_vec_3d_scalable
1397+
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm.array<4 x array<8 x vector<[16]xf32>>>
1398+
// CHECK: return {{.*}} : vector<4x8x[16]xf32>
1399+
13351400
// -----
13361401

13371402
func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
@@ -1345,6 +1410,17 @@ func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) ->
13451410
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm.array<4 x array<8 x vector<16xf32>>>
13461411
// CHECK: return {{.*}} : vector<4x8x16xf32>
13471412

1413+
func.func @insert_element_into_vec_3d_scalable(%arg0: f32, %arg1: vector<4x8x[16]xf32>) -> vector<4x8x[16]xf32> {
1414+
%0 = vector.insert %arg0, %arg1[3, 7, 15] : f32 into vector<4x8x[16]xf32>
1415+
return %0 : vector<4x8x[16]xf32>
1416+
}
1417+
// CHECK-LABEL: @insert_element_into_vec_3d_scalable
1418+
// CHECK: llvm.extractvalue {{.*}}[3, 7] : !llvm.array<4 x array<8 x vector<[16]xf32>>>
1419+
// CHECK: llvm.mlir.constant(15 : i64) : i64
1420+
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<[16]xf32>
1421+
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm.array<4 x array<8 x vector<[16]xf32>>>
1422+
// CHECK: return {{.*}} : vector<4x8x[16]xf32>
1423+
13481424
// -----
13491425

13501426
func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
@@ -1358,6 +1434,17 @@ func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2
13581434
// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
13591435
// CHECK: llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
13601436

1437+
func.func @insert_element_with_value_1d_scalable(%arg0: vector<[16]xf32>, %arg1: f32, %arg2: index)
1438+
-> vector<[16]xf32> {
1439+
%0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<[16]xf32>
1440+
return %0 : vector<[16]xf32>
1441+
}
1442+
1443+
// CHECK-LABEL: @insert_element_with_value_1d_scalable
1444+
// CHECK-SAME: %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index
1445+
// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
1446+
// CHECK: llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
1447+
13611448
// -----
13621449

13631450
func.func @insert_element_with_value_2d(%base: vector<1x16xf32>, %value: f32, %idx: index)

0 commit comments

Comments
 (0)