Skip to content

Commit b260218

Browse files
committed
[mlir][vector] Add more tests for ConvertVectorToLLVM (8/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.transfer_read * vector.transfer_write In addition: * Tests that test both xfer_read and xfer_write have their names updated to capture that (e.g. `@transfer_read_1d_mask` -> `@transfer_read_write_1d_mask`) * `@transfer_write_1d_scalable_mask` and `@transfer_read_1d_scalable_mask` are re-written as `@transfer_read_write_1d_mask_scalable`. This is to make it clear that this case is meant to complement `@transfer_read_write_1d_mask`. * `@transfer_write_tensor` is updated to also test `xfer_read`.
1 parent c2063de commit b260218

File tree

1 file changed

+191
-32
lines changed

1 file changed

+191
-32
lines changed

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 191 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,7 +2377,7 @@ func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vecto
23772377

23782378
// -----
23792379

2380-
func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
2380+
func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
23812381
%f7 = arith.constant 7.0: f32
23822382
%f = vector.transfer_read %A[%base], %f7
23832383
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2387,7 +2387,7 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
23872387
vector<17xf32>, memref<?xf32>
23882388
return %f: vector<17xf32>
23892389
}
2390-
// CHECK-LABEL: func @transfer_read_1d
2390+
// CHECK-LABEL: func @transfer_read_write_1d
23912391
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
23922392
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
23932393
// CHECK: %[[C7:.*]] = arith.constant 7.0
@@ -2449,9 +2449,77 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
24492449
// CHECK-SAME: {alignment = 4 : i32} :
24502450
// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
24512451

2452+
func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
2453+
%f7 = arith.constant 7.0: f32
2454+
%f = vector.transfer_read %A[%base], %f7
2455+
{permutation_map = affine_map<(d0) -> (d0)>} :
2456+
memref<?xf32>, vector<[17]xf32>
2457+
vector.transfer_write %f, %A[%base]
2458+
{permutation_map = affine_map<(d0) -> (d0)>} :
2459+
vector<[17]xf32>, memref<?xf32>
2460+
return %f: vector<[17]xf32>
2461+
}
2462+
// CHECK-LABEL: func @transfer_read_write_1d_scalable
2463+
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
2464+
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
2465+
// CHECK: %[[C7:.*]] = arith.constant 7.0
2466+
//
2467+
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
2468+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2469+
// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
2470+
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
2471+
//
2472+
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
2473+
// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2474+
//
2475+
// 3. Create bound vector to compute in-bound mask:
2476+
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2477+
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
2478+
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
2479+
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
2480+
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
2481+
// CHECK-SAME: : vector<[17]xi32>
2482+
//
2483+
// 4. Create pass-through vector.
2484+
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
2485+
//
2486+
// 5. Bitcast to vector form.
2487+
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
2488+
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2489+
//
2490+
// 6. Rewrite as a masked read.
2491+
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]],
2492+
// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
2493+
// CHECK-SAME: -> vector<[17]xf32>
2494+
//
2495+
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
2496+
// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
2497+
// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
2498+
// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
2499+
//
2500+
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
2501+
// CHECK: %[[linearIndex_b:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2502+
//
2503+
// 3. Create bound vector to compute in-bound mask:
2504+
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2505+
// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to i32
2506+
// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
2507+
// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
2508+
// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
2509+
// CHECK-SAME: %[[boundVect_b]] : vector<[17]xi32>
2510+
//
2511+
// 4. Bitcast to vector form.
2512+
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
2513+
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2514+
//
2515+
// 5. Rewrite as a masked write.
2516+
// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
2517+
// CHECK-SAME: {alignment = 4 : i32} :
2518+
// CHECK-SAME: vector<[17]xf32>, vector<[17]xi1> into !llvm.ptr
2519+
24522520
// -----
24532521

2454-
func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
2522+
func.func @transfer_read_write_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
24552523
%f7 = arith.constant 7: index
24562524
%f = vector.transfer_read %A[%base], %f7
24572525
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2461,7 +2529,7 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
24612529
vector<17xindex>, memref<?xindex>
24622530
return %f: vector<17xindex>
24632531
}
2464-
// CHECK-LABEL: func @transfer_read_index_1d
2532+
// CHECK-LABEL: func @transfer_read_write_index_1d
24652533
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
24662534
// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
24672535
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
@@ -2472,6 +2540,27 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
24722540
// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
24732541
// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr
24742542

2543+
func.func @transfer_read_write_index_1d_scalable(%A : memref<?xindex>, %base: index) -> vector<[17]xindex> {
2544+
%f7 = arith.constant 7: index
2545+
%f = vector.transfer_read %A[%base], %f7
2546+
{permutation_map = affine_map<(d0) -> (d0)>} :
2547+
memref<?xindex>, vector<[17]xindex>
2548+
vector.transfer_write %f, %A[%base]
2549+
{permutation_map = affine_map<(d0) -> (d0)>} :
2550+
vector<[17]xindex>, memref<?xindex>
2551+
return %f: vector<[17]xindex>
2552+
}
2553+
// CHECK-LABEL: func @transfer_read_write_index_1d
2554+
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xindex>
2555+
// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<[17]xindex>
2556+
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<[17]xindex> to vector<[17]xi64>
2557+
2558+
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
2559+
// CHECK-SAME: (!llvm.ptr, vector<[17]xi1>, vector<[17]xi64>) -> vector<[17]xi64>
2560+
2561+
// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
2562+
// CHECK-SAME: vector<[17]xi64>, vector<[17]xi1> into !llvm.ptr
2563+
24752564
// -----
24762565

24772566
func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
@@ -2501,9 +2590,34 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
25012590
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
25022591
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
25032592

2593+
func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
2594+
%f7 = arith.constant 7.0: f32
2595+
%f = vector.transfer_read %A[%base0, %base1], %f7
2596+
{permutation_map = affine_map<(d0, d1) -> (d1)>} :
2597+
memref<?x?xf32>, vector<[17]xf32>
2598+
return %f: vector<[17]xf32>
2599+
}
2600+
// CHECK-LABEL: func @transfer_read_2d_to_1d
2601+
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2602+
// CHECK: %[[c1:.*]] = arith.constant 1 : index
2603+
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
2604+
//
2605+
// Compute the in-bound index (dim - offset)
2606+
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
2607+
//
2608+
// Create a vector with linear indices [ 0 .. vector_length - 1 ].
2609+
// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
2610+
//
2611+
// Create bound vector to compute in-bound mask:
2612+
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
2613+
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
2614+
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
2615+
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
2616+
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
2617+
25042618
// -----
25052619

2506-
func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
2620+
func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
25072621
%f7 = arith.constant 7.0: f32
25082622
%f = vector.transfer_read %A[%base], %f7
25092623
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2513,7 +2627,7 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
25132627
vector<17xf32>, memref<?xf32, 3>
25142628
return %f: vector<17xf32>
25152629
}
2516-
// CHECK-LABEL: func @transfer_read_1d_non_zero_addrspace
2630+
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
25172631
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
25182632
//
25192633
// 1. Check address space for GEP is correct.
@@ -2528,6 +2642,31 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
25282642
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
25292643
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
25302644

2645+
func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32, 3>, %base: index) -> vector<[17]xf32> {
2646+
%f7 = arith.constant 7.0: f32
2647+
%f = vector.transfer_read %A[%base], %f7
2648+
{permutation_map = affine_map<(d0) -> (d0)>} :
2649+
memref<?xf32, 3>, vector<[17]xf32>
2650+
vector.transfer_write %f, %A[%base]
2651+
{permutation_map = affine_map<(d0) -> (d0)>} :
2652+
vector<[17]xf32>, memref<?xf32, 3>
2653+
return %f: vector<[17]xf32>
2654+
}
2655+
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
2656+
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2657+
//
2658+
// 1. Check address space for GEP is correct.
2659+
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
2660+
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
2661+
//
2662+
// 2. Check address space of the memref is correct.
2663+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
2664+
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
2665+
//
2666+
// 3. Check address space for GEP is correct.
2667+
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
2668+
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
2669+
25312670
// -----
25322671

25332672
func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
@@ -2546,51 +2685,71 @@ func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector
25462685
// 2. Rewrite as a load.
25472686
// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<17xf32>
25482687

2688+
func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
2689+
%f7 = arith.constant 7.0: f32
2690+
%f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} :
2691+
memref<?xf32>, vector<[17]xf32>
2692+
return %f: vector<[17]xf32>
2693+
}
2694+
// CHECK-LABEL: func @transfer_read_1d_inbounds_scalable
2695+
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
2696+
//
2697+
// 1. Bitcast to vector form.
2698+
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
2699+
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
2700+
//
2701+
// 2. Rewrite as a load.
2702+
// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<[17]xf32>
2703+
25492704
// -----
25502705

2551-
// CHECK-LABEL: func @transfer_read_1d_mask
2706+
// CHECK-LABEL: func @transfer_read_write_1d_mask
25522707
// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
25532708
// CHECK: %[[cmpi:.*]] = arith.cmpi slt
25542709
// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
25552710
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
2711+
// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
2712+
// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
2713+
// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
25562714
// CHECK: return %[[r]]
2557-
func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
2715+
func.func @transfer_read_write_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
25582716
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
25592717
%f7 = arith.constant 7.0: f32
25602718
%f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
2719+
vector.transfer_write %f, %A[%base], %m : vector<5xf32>, memref<?xf32>
25612720
return %f: vector<5xf32>
25622721
}
25632722

2564-
// -----
2565-
2566-
// CHECK-LABEL: func @transfer_read_1d_scalable_mask
2567-
// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
2568-
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
2569-
// CHECK: return %[[r]] : vector<[4]xf32>
2570-
func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
2571-
%c0 = arith.constant 0 : index
2572-
%pad = arith.constant 0.0 : f32
2573-
%vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
2574-
return %vec : vector<[4]xf32>
2723+
// CHECK-LABEL: func @transfer_read_write_1d_mask_scalable
2724+
// CHECK-SAME: %[[mask:[a-zA-Z0-9]*]]: vector<[5]xi1>
2725+
// CHECK: %[[cmpi:.*]] = arith.cmpi slt
2726+
// CHECK: %[[mask1:.*]] = arith.andi %[[cmpi]], %[[mask]]
2727+
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask1]]
2728+
// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
2729+
// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi_1]], %[[mask]]
2730+
// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask2]]
2731+
// CHECK: return %[[r]]
2732+
func.func @transfer_read_write_1d_mask_scalable(%A : memref<?xf32>, %base : index, %m : vector<[5]xi1>) -> vector<[5]xf32> {
2733+
%f7 = arith.constant 7.0: f32
2734+
%f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<[5]xf32>
2735+
vector.transfer_write %f, %A[%base], %m : vector<[5]xf32>, memref<?xf32>
2736+
return %f: vector<[5]xf32>
25752737
}
25762738

25772739
// -----
2578-
// CHECK-LABEL: func @transfer_write_1d_scalable_mask
2579-
// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
2580-
func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
2581-
%c0 = arith.constant 0 : index
2582-
vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
2583-
return
2584-
}
25852740

2586-
// -----
2741+
// Can't lower xfer_read/xfer_write on tensors, but this shouldn't crash
25872742

2588-
// CHECK-LABEL: func @transfer_write_tensor
2743+
// CHECK-LABEL: func @transfer_read_write_tensor
2744+
// CHECK: vector.transfer_read
25892745
// CHECK: vector.transfer_write
2590-
func.func @transfer_write_tensor(%arg0: vector<4xf32>,%arg1: tensor<?xf32>) -> tensor<?xf32> {
2591-
%c0 = arith.constant 0 : index
2592-
%0 = vector.transfer_write %arg0, %arg1[%c0] : vector<4xf32>, tensor<?xf32>
2593-
return %0 : tensor<?xf32>
2746+
func.func @transfer_read_write_tensor(%A: tensor<?xf32>, %base : index) -> vector<4xf32> {
2747+
%f7 = arith.constant 7.0: f32
2748+
%c0 = arith.constant 0: index
2749+
%f = vector.transfer_read %A[%base], %f7 : tensor<?xf32>, vector<4xf32>
2750+
%w = vector.transfer_write %f, %A[%c0] : vector<4xf32>, tensor<?xf32>
2751+
"test.some_use"(%w) : (tensor<?xf32>) -> ()
2752+
return %f : vector<4xf32>
25942753
}
25952754

25962755
// -----

0 commit comments

Comments
 (0)