@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
79
79
}
80
80
81
81
// CHECK-LABEL: func.func @masked_extract_contract4(
82
- // CHECK-SAME: %[[VAL_0:.*]] : vector<3x5xf32>,
83
- // CHECK-SAME: %[[VAL_1:.*]] : vector<5x7xf32>,
84
- // CHECK-SAME: %[[VAL_2:.*]] : vector<3x7xf32>,
85
- // CHECK-SAME: %[[VAL_3 :.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
86
- // CHECK: %[[VAL_5 :.*]] = vector.transpose %[[VAL_3 ]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
87
- // CHECK: %[[VAL_8 :.*]] = vector.extract %[[VAL_5 ]][0] : vector<3x7xi1> from vector<5x3x7xi1>
88
- // CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8 ]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
89
- // CHECK: %[[VAL_12 :.*]] = vector.extract %[[VAL_5 ]][1] : vector<3x7xi1> from vector<5x3x7xi1>
90
- // CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12 ]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
91
- // CHECK: %[[VAL_16 :.*]] = vector.extract %[[VAL_5 ]][2] : vector<3x7xi1> from vector<5x3x7xi1>
92
- // CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16 ]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
93
- // CHECK: %[[VAL_20 :.*]] = vector.extract %[[VAL_5 ]][3] : vector<3x7xi1> from vector<5x3x7xi1>
94
- // CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20 ]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
95
- // CHECK: %[[VAL_24 :.*]] = vector.extract %[[VAL_5 ]][4] : vector<3x7xi1> from vector<5x3x7xi1>
96
- // CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24 ]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
82
+ // CHECK-SAME: %{{.*}} : vector<3x5xf32>,
83
+ // CHECK-SAME: %{{.*}} : vector<5x7xf32>,
84
+ // CHECK-SAME: %{{.*}} : vector<3x7xf32>,
85
+ // CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
86
+ // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
87
+ // CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
88
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
89
+ // CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
90
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
91
+ // CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
92
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
93
+ // CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
94
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
95
+ // CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
96
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
97
97
98
98
func.func @masked_extract_contract4 (%arg0: vector <3 x5 xf32 >,
99
99
%arg1: vector <5 x7 xf32 >,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
104
104
return %0 : vector <3 x7 xf32 >
105
105
}
106
106
107
+ // CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
108
+ // CHECK-SAME: %{{.*}}: vector<3x5xf32>,
109
+ // CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
110
+ // CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
111
+ // CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
112
+ // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
113
+ // CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
114
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
115
+ // CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
116
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
117
+ // CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
118
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
119
+ // CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
120
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
121
+ // CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
122
+ // CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
123
+
124
+ // Note that only the J dimension is scalable in this example. In theory, all
125
+ // dimensions could be scalable, but there is no target yet for which this
126
+ // would make sense.
127
+ func.func @masked_extract_contract4_scalable_J_dim (%arg0: vector <3 x5 xf32 >,
128
+ %arg1: vector <5 x[7 ]xf32 >,
129
+ %arg2: vector <3 x[7 ]xf32 >,
130
+ %m : vector <3 x[7 ]x5 xi1 >) -> vector <3 x[7 ]xf32 > {
131
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0 , %arg1 , %arg2
132
+ : vector <3 x5 xf32 >, vector <5 x[7 ]xf32 > into vector <3 x[7 ]xf32 > } : vector <3 x[7 ]x5 xi1 > -> vector <3 x[7 ]xf32 >
133
+ return %0 : vector <3 x[7 ]xf32 >
134
+ }
135
+
107
136
// CHECK-LABEL: func @matmul
108
137
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
109
138
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
0 commit comments