Skip to content

Commit 74d8670

Browse files
committed
[mlir][SVE] Add an e2e test for vector.contract
Adds an end-to-end test for `vector.contract` that targets SVE (i.e. scalable vectors). Note that this requires lifting the restriction on `vector.outerproduct` (to which `vector.contract` is lowered) that would cause the Op verifier to reject the following as invalid (*): ``` vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32> ``` This is indeed valid, as the end-to-end test demonstrates (at least when compiling for SVE). Depends on #68794
1 parent 64025b8 commit 74d8670

File tree

4 files changed

+191
-19
lines changed

4 files changed

+191
-19
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
30673067
return emitOpError("expected #1 operand dim to match result dim #1");
30683068
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
30693069
return emitOpError("expected #2 operand dim to match result dim #2");
3070-
if (vRHS.isScalable() != vLHS.isScalable())
3071-
return emitOpError("expected either all or none of vector operands #1 "
3072-
"and #2 to be scalable");
3070+
if (vLHS.isScalable() && !vRHS.isScalable()) {
3071+
// This restriction reflects what's currently supported in terms of
3072+
// scalable vectors. However, we could relax this if there's a use case.
3073+
return emitOpError(
3074+
"expected either both or only #2 operand dim to be scalable");
3075+
}
30733076
} else {
30743077
// An AXPY operation.
30753078
if (vRES.getRank() != 1)

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

+44-15
Original file line numberDiff line numberDiff line change
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
7979
}
8080

8181
// CHECK-LABEL: func.func @masked_extract_contract4(
82-
// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
83-
// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
84-
// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
85-
// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
86-
// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
87-
// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
88-
// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
89-
// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
90-
// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
91-
// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
92-
// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
93-
// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
94-
// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
95-
// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
96-
// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
82+
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
83+
// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
84+
// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
85+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
86+
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
87+
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
88+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
89+
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
90+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
91+
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
92+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
93+
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
94+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
95+
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
96+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
9797

9898
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
9999
%arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
104104
return %0 : vector<3x7xf32>
105105
}
106106

107+
// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
108+
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
109+
// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
110+
// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
111+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
112+
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
113+
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
114+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
115+
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
116+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
117+
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
118+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
119+
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
120+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
121+
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
122+
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
123+
124+
// Note that only the J dimension is scalable in this example. In theory, all
125+
// dimensions could be scalable, but there is no target yet for which this
126+
// would make sense.
127+
func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
128+
%arg1: vector<5x[7]xf32>,
129+
%arg2: vector<3x[7]xf32>,
130+
%m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
131+
%0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
132+
: vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
133+
return %0 : vector<3x[7]xf32>
134+
}
135+
107136
// CHECK-LABEL: func @matmul
108137
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
109138
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,

mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
2121
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
2222
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
2323

24-
// expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
24+
// expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
2525
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
26+
27+
return
2628
}
29+
2730
// -----
2831

2932
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
2+
// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
3+
// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm
4+
// DEFINE: %{entry} =
5+
// DEFINE: %{run} = %mcr_aarch64_cmd -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
6+
7+
// REDEFINE: %{entry} = entry_i32
8+
// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=I32
9+
10+
// REDEFINE: %{entry} = entry_f32
11+
// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=F32
12+
13+
#matmat_accesses = [
14+
affine_map<(i, j, k) -> (i, k)>,
15+
affine_map<(i, j, k) -> (k, j)>,
16+
affine_map<(i, j, k) -> (i, j)>
17+
]
18+
#matmat_trait = {
19+
indexing_maps = #matmat_accesses,
20+
iterator_types = ["parallel", "parallel", "reduction"]
21+
}
22+
23+
func.func @entry_i32() {
24+
%vscale = vector.vscale
25+
26+
%c0 = arith.constant 0 : index
27+
%c2 = arith.constant 2 : index
28+
%c3 = arith.constant 3 : index
29+
%c5 = arith.constant 5 : index
30+
%n_rows = arith.muli %vscale, %c2 : index
31+
32+
%cst = arith.constant 0: i32
33+
%i32_123 = arith.constant 123 : i32
34+
%i32_314 = arith.constant 314 : i32
35+
36+
// Allocate and initialize matrix A
37+
%A_alloc = memref.alloca() : memref<3x5xi32>
38+
linalg.fill ins(%i32_123 : i32) outs(%A_alloc :memref<3x5xi32>)
39+
%mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
40+
%vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xi32>, vector<3x5xi32>
41+
42+
// Allocate and initialize matrix B
43+
%B_alloc = memref.alloca(%n_rows) : memref<5x?xi32>
44+
linalg.fill ins(%i32_123 : i32) outs(%B_alloc :memref<5x?xi32>)
45+
%mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
46+
%vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xi32>, vector<5x[2]xi32>
47+
48+
// Allocate and initialize matrix C
49+
%C_alloc = memref.alloca(%n_rows) : memref<3x?xi32>
50+
linalg.fill ins(%i32_314 : i32) outs(%C_alloc :memref<3x?xi32>)
51+
%mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
52+
%vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xi32>, vector<3x[2]xi32>
53+
54+
// Matmul
55+
%m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
56+
%0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
57+
: vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32> } : vector<3x[2]x5xi1> -> vector<3x[2]xi32>
58+
59+
// Print the output
60+
%slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
61+
// I32: ( 75959, 75959, 75959, 75959
62+
vector.print %slice1 : vector<[2]xi32>
63+
%slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
64+
// I32-NEXT: ( 75959, 75959, 75959, 75959
65+
vector.print %slice2 : vector<[2]xi32>
66+
%slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
67+
// I32-NEXT: ( 75959, 75959, 75959, 75959
68+
vector.print %slice3 : vector<[2]xi32>
69+
70+
// I32: SVE: END OF TEST OUTPUT
71+
vector.print str "SVE: END OF TEST OUTPUT"
72+
73+
return
74+
}
75+
76+
func.func @entry_f32() {
77+
%vscale = vector.vscale
78+
79+
%c0 = arith.constant 0 : index
80+
%c2 = arith.constant 2 : index
81+
%c3 = arith.constant 3 : index
82+
%c5 = arith.constant 5 : index
83+
%n_rows = arith.muli %vscale, %c2 : index
84+
85+
%cst = arith.constant 0.0: f32
86+
%f32_123 = arith.constant 1.23 : f32
87+
%f32_314 = arith.constant 3.14 : f32
88+
89+
// Allocate and initialize matrix A
90+
%A_alloc = memref.alloca() : memref<3x5xf32>
91+
linalg.fill ins(%f32_123 : f32) outs(%A_alloc :memref<3x5xf32>)
92+
%mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
93+
%vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xf32>, vector<3x5xf32>
94+
95+
// Allocate and initialize matrix B
96+
%B_alloc = memref.alloca(%n_rows) : memref<5x?xf32>
97+
linalg.fill ins(%f32_123 : f32) outs(%B_alloc :memref<5x?xf32>)
98+
%mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
99+
%vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xf32>, vector<5x[2]xf32>
100+
101+
// Allocate and initialize matrix C
102+
%C_alloc = memref.alloca(%n_rows) : memref<3x?xf32>
103+
linalg.fill ins(%f32_314 : f32) outs(%C_alloc :memref<3x?xf32>)
104+
%mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
105+
%vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[2]xf32>
106+
107+
// Matmul
108+
%m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
109+
%0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
110+
: vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32> } : vector<3x[2]x5xi1> -> vector<3x[2]xf32>
111+
112+
// Print the output
113+
%slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
114+
// F32: ( 10.7045, 10.7045, 10.7045, 10.7045
115+
vector.print %slice1 : vector<[2]xf32>
116+
%slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
117+
// F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
118+
vector.print %slice2 : vector<[2]xf32>
119+
%slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
120+
// F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
121+
vector.print %slice3 : vector<[2]xf32>
122+
123+
// F32: SVE: END OF TEST OUTPUT
124+
vector.print str "SVE: END OF TEST OUTPUT"
125+
126+
return
127+
}
128+
129+
transform.sequence failures(propagate) {
130+
^bb1(%module_op: !transform.any_op):
131+
%f = transform.structured.match ops{["func.func"]} in %module_op
132+
: (!transform.any_op) -> !transform.any_op
133+
134+
transform.apply_patterns to %f {
135+
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
136+
} : !transform.any_op
137+
}

0 commit comments

Comments
 (0)