diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index cc4417077d459..2ce3ad875fa45 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4053,6 +4053,8 @@ def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>; def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>; def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>; +// NOTE: This is an attribute in the SPIR-V *dialect* but a constant () in +// SPIR-V proper. def SPIRV_KHR_CooperativeMatrixUseAttr : SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR", "valid SPIR-V Cooperative Matrix Use (KHR)", @@ -4064,6 +4066,8 @@ def SPIRV_KHR_CooperativeMatrixUseAttr : def SPIRV_KHR_CML_RowMajor : I32EnumAttrCase<"RowMajor", 0>; def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>; +// NOTE: This is an attribute in the SPIR-V *dialect* but a constant () in +// SPIR-V proper. def SPIRV_KHR_CooperativeMatrixLayoutAttr : SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR", "valid SPIR-V Cooperative Matrix Layout (KHR)", diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 3ce43c7e2b1fc..b5ea0774f589d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -55,12 +55,14 @@ def SPIRV_KHRCooperativeMatrixLengthOp : ]; let arguments = (ins - TypeAttr:$cooperative_matrix_type + TypeAttrOf:$cooperative_matrix_type ); let results = (outs SPIRV_Int32:$result ); + + let hasVerifier = false; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index 600813f361a47..77dbf130c7778 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -19,20 +19,6 @@ using namespace mlir::spirv::AttrNames; namespace mlir::spirv { -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLength -//===----------------------------------------------------------------------===// - -LogicalResult KHRCooperativeMatrixLengthOp::verify() { - if (!isa(getCooperativeMatrixType())) { - return emitOpError( - "type attribute must be a '!spirv.coopmatrix' type, found ") - << getCooperativeMatrixType() << " instead"; - } - - return success(); -} - //===----------------------------------------------------------------------===// // spirv.KHR.CooperativeMatrixLoad //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index 78afcc7003eff..7510e1e2eb9b6 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: + case spirv::Opcode::OpTypeCooperativeMatrixKHR: case spirv::Opcode::OpTypeCooperativeMatrixNV: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b84d1d9c21879..ce8b3ab389460 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -765,8 +765,10 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); + case spirv::Opcode::OpTypeCooperativeMatrixKHR: + return processCooperativeMatrixTypeKHR(operands); case spirv::Opcode::OpTypeCooperativeMatrixNV: - return processCooperativeMatrixType(operands); + return processCooperativeMatrixTypeNV(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeJointMatrixINTEL: @@ -900,32 +902,76 @@ spirv::Deserializer::processFunctionType(ArrayRef operands) { return success(); } -LogicalResult -spirv::Deserializer::processCooperativeMatrixType(ArrayRef operands) { +LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( + ArrayRef operands) { + if (operands.size() != 6) { + return emitError(unknownLoc, + "OpTypeCooperativeMatrixKHR must have element type, " + "scope, row and column parameters, and use"); + } + + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, + "OpTypeCooperativeMatrixKHR references undefined ") + << operands[1]; + } + + std::optional scope = + spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); + if (!scope) { + return emitError( + unknownLoc, + "OpTypeCooperativeMatrixKHR references undefined scope ") + << operands[2]; + } + + unsigned rows = getConstantInt(operands[3]).getInt(); + unsigned columns = getConstantInt(operands[4]).getInt(); + + std::optional use = + spirv::symbolizeCooperativeMatrixUseKHR( + getConstantInt(operands[5]).getInt()); + if (!use) { + return emitError( + unknownLoc, + "OpTypeCooperativeMatrixKHR references undefined use ") + << operands[5]; + } + + typeMap[operands[0]] = + spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use); + return success(); +} + +LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV( + ArrayRef operands) { if (operands.size() != 5) { - return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " + return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element " "type and row x column parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, - "OpTypeCooperativeMatrix references undefined ") + "OpTypeCooperativeMatrixNV references undefined ") << operands[1]; } - auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); + std::optional scope = + spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); if (!scope) { - return emitError(unknownLoc, - "OpTypeCooperativeMatrix references undefined scope ") + return emitError( + unknownLoc, + "OpTypeCooperativeMatrixNV references undefined scope ") << operands[2]; } unsigned rows = getConstantInt(operands[3]).getInt(); unsigned columns = getConstantInt(operands[4]).getInt(); - typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( - elementTy, scope.value(), rows, columns); + typeMap[operands[0]] = + spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns); return success(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 613e4f6738df6..69be47851ef3c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -254,7 +254,9 @@ class Deserializer { LogicalResult processArrayType(ArrayRef operands); - LogicalResult processCooperativeMatrixType(ArrayRef operands); + LogicalResult processCooperativeMatrixTypeKHR(ArrayRef operands); + + LogicalResult processCooperativeMatrixTypeNV(ArrayRef operands); LogicalResult processFunctionType(ArrayRef operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 1ef8ff043e690..dad085e21b427 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -593,6 +593,28 @@ LogicalResult Serializer::prepareBasicType( return success(); } + if (auto cooperativeMatrixType = + dyn_cast(type)) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR; + auto getConstantOp = [&](uint32_t id) { + auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); + return prepareConstantInt(loc, attr); + }; + operands.push_back(elementTypeID); + operands.push_back( + getConstantOp(static_cast(cooperativeMatrixType.getScope()))); + operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); + operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); + operands.push_back( + getConstantOp(static_cast(cooperativeMatrixType.getUse()))); + return success(); + } + if (auto cooperativeMatrixType = dyn_cast(type)) { uint32_t elementTypeID = 0; diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir index 03acb0c08b275..40736367520e8 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -14,7 +14,7 @@ spirv.func @cooperative_matrix_length() -> i32 "None" { // ----- spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { - // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}} + // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}} %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.ReturnValue %0 : i32 } diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir new file mode 100644 index 0000000000000..8546172f4f797 --- /dev/null +++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip \ +// RUN: --split-input-file %s | FileCheck %s + +spirv.module Logical GLSL450 requires + #spirv.vce { + + // CHECK-LABEL: @cooperative_matrix_length + spirv.func @cooperative_matrix_length() "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB> + %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB> + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_load_1 + spirv.func @cooperative_matrix_load_1(%ptr : !spirv.ptr, %stride : i32) "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, + // CHECK-SAME: : !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_load_2 + spirv.func @cooperative_matrix_load_2(%ptr : !spirv.ptr, %stride : i64) "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, , + // CHECK-SAME: : !spirv.ptr, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_store_1 + spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, + // CHECK-SAME: : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : + !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_store_2 + spirv.func @cooperative_matrix_store_2(%ptr : !spirv.ptr, %stride : i64, + %m : !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , + // CHECK-SAME: : !spirv.ptr, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : + !spirv.ptr, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64 + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_muladd + spirv.func @cooperative_matrix_muladd_1(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>, + %b : !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>, + %c : !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>) "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : + // CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>, + // CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB> + // CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc> + %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>, + !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB> + -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc> + + // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>, + // CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB> + // CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc> + %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, + : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>, + !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB> + -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc> + + // TODO: Handle multiple matrix operands and add relevant testcases here. + spirv.Return + } + + // CHECK-LABEL: @cooperative_matrix_muladd + spirv.func @cooperative_matrix_muladd_2(%a : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>, + %b : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>, + %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : + // CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>, + // CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB> + // CHECK-SAME: -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc> + %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>, + !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB> + -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc> + + spirv.Return + } + +} diff --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir similarity index 100% rename from mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir rename to mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 8468f92600a44..ac00ddc6422c6 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -16,6 +16,7 @@ #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -512,6 +513,14 @@ static mlir::GenRegistration // Serialization AutoGen //===----------------------------------------------------------------------===// +// These enums are encoded as to constant values in SPIR-V blob, but we +// directly use the constant value as attribute in SPIR-V dialect. So need +// to handle them separately from normal enum attributes. +constexpr llvm::StringLiteral constantIdEnumAttrs[] = { + "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr", + "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr", + "SPIRV_MatrixLayoutAttr"}; + /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The /// generates code extracts the attribute with name `attrName` from /// `operandList` of `op`. @@ -521,12 +530,7 @@ static void emitAttributeSerialization(const Attribute &attr, StringRef attrName, raw_ostream &os) { os << tabs << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); - if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || - attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || - attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { - // These two enums are encoded as to constant values in SPIR-V blob, - // but we directly use the constant value as attribute in SPIR-V dialect. So - // need to handle them separately from normal enum attributes. + if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " @@ -557,11 +561,18 @@ static void emitAttributeSerialization(const Attribute &attr, " {0}.push_back(static_cast(" "llvm::cast(attr).getValue().getZExtValue()));\n", operandList); - } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") { + } else if (attr.isEnumAttr() || attr.isTypeAttr()) { + // It may be the first time this type appears in the IR, so we need to + // process it. + StringRef attrTypeID = "attrTypeID"; + os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID); os << tabs - << formatv(" {0}.push_back(static_cast(" - "getTypeID(llvm::cast(attr).getValue())));\n", - operandList); + << formatv(" if (failed(processType({0}.getLoc(), " + "llvm::cast(attr).getValue(), {1}))) {{\n", + opVar, attrTypeID); + os << tabs << " return failure();\n"; + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList); } else { PrintFatalError( loc, @@ -816,12 +827,7 @@ static void emitAttributeDeserialization(const Attribute &attr, StringRef attrList, StringRef attrName, StringRef words, StringRef wordIndex, raw_ostream &os) { - if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || - attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || - attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { - // These two enums are encoded as to constant values in SPIR-V blob, - // but we directly use the constant value as attribute in SPIR-V dialect. So - // need to handle them separately from normal enum attributes. + if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); os << tabs << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "