Skip to content

Commit d13da15

Browse files
authored
[mlir][spirv] Define KHR cooperative matrix properties (#66823)
1 parent ab2c104 commit d13da15

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td

+28
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri
5454
let assemblyFormat = "`<` struct(params) `>`";
5555
}
5656

57+
// Description of cooperative matrix operations supported on the
58+
// target. Represents `VkCooperativeMatrixPropertiesKHR`. See
59+
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkCooperativeMatrixPropertiesKHR.html
60+
def SPIRV_CooperativeMatrixPropertiesKHRAttr :
61+
SPIRV_Attr<"CooperativeMatrixPropertiesKHR", "coop_matrix_props_khr"> {
62+
let parameters = (ins
63+
"uint32_t":$m_size,
64+
"uint32_t":$n_size,
65+
"uint32_t":$k_size,
66+
"mlir::Type":$a_type,
67+
"mlir::Type":$b_type,
68+
"mlir::Type":$c_type,
69+
"mlir::Type":$result_type,
70+
"bool":$acc_sat,
71+
"mlir::spirv::ScopeAttr":$scope
72+
);
73+
let assemblyFormat = "`<` struct(params) `>`";
74+
}
75+
76+
def SPIRV_CooperativeMatrixPropertiesKHRArrayAttr :
77+
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesKHRAttr,
78+
"CooperativeMatrixPropertiesKHR array attribute">;
79+
5780
// Description of cooperative matrix operations supported on the
5881
// target. Represents `VkCooperativeMatrixPropertiesNV`. See
5982
// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkCooperativeMatrixPropertiesNV.html
@@ -130,6 +153,11 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> {
130153

131154
// The configurations of cooperative matrix operations
132155
// supported. Default is an empty list.
156+
DefaultValuedParameter<
157+
"ArrayAttr",
158+
"nullptr"
159+
>:$cooperative_matrix_properties_khr,
160+
133161
DefaultValuedParameter<
134162
"ArrayAttr",
135163
"nullptr"

mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
166166
/*subgroup_size=*/32,
167167
/*min_subgroup_size=*/std::nullopt,
168168
/*max_subgroup_size=*/std::nullopt,
169-
/*cooperative_matrix_properties_nv=*/ArrayAttr());
169+
/*cooperative_matrix_properties_khr=*/ArrayAttr{},
170+
/*cooperative_matrix_properties_nv=*/ArrayAttr{});
170171
}
171172

172173
StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }

mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir

+41
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,47 @@ func.func @target_env_extra_fields() attributes {
208208

209209
// -----
210210

211+
func.func @target_env_cooperative_matrix_khr() attributes{
212+
// CHECK: spirv.target_env = #spirv.target_env<
213+
// CHECK-SAME: SPV_KHR_cooperative_matrix
214+
// CHECK-SAME: #spirv.coop_matrix_props_khr<
215+
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 32,
216+
// CHECK-SAME: a_type = i8, b_type = i8, c_type = i32,
217+
// CHECK-SAME: result_type = i32, acc_sat = true, scope = <Subgroup>>
218+
// CHECK-SAME: #spirv.coop_matrix_props_khr<
219+
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 16,
220+
// CHECK-SAME: a_type = f16, b_type = f16, c_type = f16,
221+
// CHECK-SAME: result_type = f16, acc_sat = false, scope = <Subgroup>>
222+
spirv.target_env = #spirv.target_env<
223+
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
224+
SPV_KHR_cooperative_matrix]>,
225+
#spirv.resource_limits<
226+
cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<
227+
m_size = 8,
228+
n_size = 8,
229+
k_size = 32,
230+
a_type = i8,
231+
b_type = i8,
232+
c_type = i32,
233+
result_type = i32,
234+
acc_sat = true,
235+
scope = #spirv.scope<Subgroup>
236+
>, #spirv.coop_matrix_props_khr<
237+
m_size = 8,
238+
n_size = 8,
239+
k_size = 16,
240+
a_type = f16,
241+
b_type = f16,
242+
c_type = f16,
243+
result_type = f16,
244+
acc_sat = false,
245+
scope = #spirv.scope<Subgroup>
246+
>]
247+
>>
248+
} { return }
249+
250+
// -----
251+
211252
func.func @target_env_cooperative_matrix_nv() attributes{
212253
// CHECK: spirv.target_env = #spirv.target_env<
213254
// CHECK-SAME: SPV_NV_cooperative_matrix

0 commit comments

Comments
 (0)