|
| 1 | +// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\ |
| 2 | +// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\ |
| 3 | +// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm |
| 4 | +// DEFINE: %{entry} = |
| 5 | +// DEFINE: %{run} = %mcr_aarch64_cmd -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext |
| 6 | + |
| 7 | +// REDEFINE: %{entry} = entry_i32 |
| 8 | +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=I32 |
| 9 | + |
| 10 | +// REDEFINE: %{entry} = entry_f32 |
| 11 | +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=F32 |
| 12 | + |
| 13 | +#matmat_accesses = [ |
| 14 | + affine_map<(i, j, k) -> (i, k)>, |
| 15 | + affine_map<(i, j, k) -> (k, j)>, |
| 16 | + affine_map<(i, j, k) -> (i, j)> |
| 17 | +] |
| 18 | +#matmat_trait = { |
| 19 | + indexing_maps = #matmat_accesses, |
| 20 | + iterator_types = ["parallel", "parallel", "reduction"] |
| 21 | +} |
| 22 | + |
| 23 | +func.func @entry_i32() { |
| 24 | + %vscale = vector.vscale |
| 25 | + |
| 26 | + %c0 = arith.constant 0 : index |
| 27 | + %c2 = arith.constant 2 : index |
| 28 | + %c3 = arith.constant 3 : index |
| 29 | + %c5 = arith.constant 5 : index |
| 30 | + %n_rows = arith.muli %vscale, %c2 : index |
| 31 | + |
| 32 | + %cst = arith.constant 0: i32 |
| 33 | + %i32_123 = arith.constant 123 : i32 |
| 34 | + %i32_314 = arith.constant 314 : i32 |
| 35 | + |
| 36 | + // Allocate and initialize matrix A |
| 37 | + %A_alloc = memref.alloca() : memref<3x5xi32> |
| 38 | + linalg.fill ins(%i32_123 : i32) outs(%A_alloc :memref<3x5xi32>) |
| 39 | + %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1> |
| 40 | + %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xi32>, vector<3x5xi32> |
| 41 | + |
| 42 | + // Allocate and initialize matrix B |
| 43 | + %B_alloc = memref.alloca(%n_rows) : memref<5x?xi32> |
| 44 | + linalg.fill ins(%i32_123 : i32) outs(%B_alloc :memref<5x?xi32>) |
| 45 | + %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1> |
| 46 | + %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xi32>, vector<5x[2]xi32> |
| 47 | + |
| 48 | + // Allocate and initialize matrix C |
| 49 | + %C_alloc = memref.alloca(%n_rows) : memref<3x?xi32> |
| 50 | + linalg.fill ins(%i32_314 : i32) outs(%C_alloc :memref<3x?xi32>) |
| 51 | + %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1> |
| 52 | + %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xi32>, vector<3x[2]xi32> |
| 53 | + |
| 54 | + // Matmul |
| 55 | + %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1> |
| 56 | + %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c |
| 57 | + : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32> } : vector<3x[2]x5xi1> -> vector<3x[2]xi32> |
| 58 | + |
| 59 | + // Print the output |
| 60 | + %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32> |
| 61 | + // I32: ( 75959, 75959, 75959, 75959 |
| 62 | + vector.print %slice1 : vector<[2]xi32> |
| 63 | + %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32> |
| 64 | + // I32-NEXT: ( 75959, 75959, 75959, 75959 |
| 65 | + vector.print %slice2 : vector<[2]xi32> |
| 66 | + %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32> |
| 67 | + // I32-NEXT: ( 75959, 75959, 75959, 75959 |
| 68 | + vector.print %slice3 : vector<[2]xi32> |
| 69 | + |
| 70 | + // CHECK: SVE: END OF TEST OUTPUT |
| 71 | + vector.print str "SVE: END OF TEST OUTPUT" |
| 72 | + |
| 73 | + return |
| 74 | +} |
| 75 | + |
| 76 | +func.func @entry_f32() { |
| 77 | + %vscale = vector.vscale |
| 78 | + |
| 79 | + %c0 = arith.constant 0 : index |
| 80 | + %c2 = arith.constant 2 : index |
| 81 | + %c3 = arith.constant 3 : index |
| 82 | + %c5 = arith.constant 5 : index |
| 83 | + %n_rows = arith.muli %vscale, %c2 : index |
| 84 | + |
| 85 | + %cst = arith.constant 0.0: f32 |
| 86 | + %f32_123 = arith.constant 1.23 : f32 |
| 87 | + %f32_314 = arith.constant 3.14 : f32 |
| 88 | + |
| 89 | + // Allocate and initialize matrix A |
| 90 | + %A_alloc = memref.alloca() : memref<3x5xf32> |
| 91 | + linalg.fill ins(%f32_123 : f32) outs(%A_alloc :memref<3x5xf32>) |
| 92 | + %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1> |
| 93 | + %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xf32>, vector<3x5xf32> |
| 94 | + |
| 95 | + // Allocate and initialize matrix B |
| 96 | + %B_alloc = memref.alloca(%n_rows) : memref<5x?xf32> |
| 97 | + linalg.fill ins(%f32_123 : f32) outs(%B_alloc :memref<5x?xf32>) |
| 98 | + %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1> |
| 99 | + %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xf32>, vector<5x[2]xf32> |
| 100 | + |
| 101 | + // Allocate and initialize matrix C |
| 102 | + %C_alloc = memref.alloca(%n_rows) : memref<3x?xf32> |
| 103 | + linalg.fill ins(%f32_314 : f32) outs(%C_alloc :memref<3x?xf32>) |
| 104 | + %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1> |
| 105 | + %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[2]xf32> |
| 106 | + |
| 107 | + // Matmul |
| 108 | + %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1> |
| 109 | + %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c |
| 110 | + : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32> } : vector<3x[2]x5xi1> -> vector<3x[2]xf32> |
| 111 | + |
| 112 | + // Print the output |
| 113 | + %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32> |
| 114 | + // F32: ( 10.7045, 10.7045, 10.7045, 10.7045 |
| 115 | + vector.print %slice1 : vector<[2]xf32> |
| 116 | + %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32> |
| 117 | + // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045 |
| 118 | + vector.print %slice2 : vector<[2]xf32> |
| 119 | + %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32> |
| 120 | + // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045 |
| 121 | + vector.print %slice3 : vector<[2]xf32> |
| 122 | + |
| 123 | + // CHECK: SVE: END OF TEST OUTPUT |
| 124 | + vector.print str "SVE: END OF TEST OUTPUT" |
| 125 | + |
| 126 | + return |
| 127 | +} |
| 128 | + |
| 129 | +transform.sequence failures(propagate) { |
| 130 | +^bb1(%module_op: !transform.any_op): |
| 131 | + %f = transform.structured.match ops{["func.func"]} in %module_op |
| 132 | + : (!transform.any_op) -> !transform.any_op |
| 133 | + |
| 134 | + transform.apply_patterns to %f { |
| 135 | + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" |
| 136 | + } : !transform.any_op |
| 137 | +} |
0 commit comments