diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 7060aa80dc113..31d26706faecb 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -101,20 +101,32 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" ``` {.ebnf} cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad` ssa-use `,` ssa-use `,` - cooperative-matrix-layout `,` - (`[` memory-operand `]`)? ` : ` - pointer-type `as` cooperative-matrix-type + `<` cooperative-matrix-layout `>` + (`,` `<` memory-operand `>`)? `:` + pointer-type `,` stride-type `->` cooperative-matrix-type ``` + TODO: In the SPIR-V spec, `stride` is an optional argument. We should also + support this optionality in the SPIR-V dialect. + #### Example: ``` - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor - : !spirv.ptr - as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, + : !spirv.ptr, i32 + -> !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA> + + %1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , + : !spirv.ptr, i64 + -> !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc> ``` }]; + let assemblyFormat = [{ + $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:` + type(operands) `->` type($result) + }]; + let availability = [ MinVersion, MaxVersion, @@ -124,8 +136,8 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" let arguments = (ins SPIRV_AnyPtr:$pointer, - SPIRV_Integer:$stride, SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout, + SPIRV_Integer:$stride, OptionalAttr:$memory_operand ); diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index bc1d30f555183..4f986065d8d9c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -37,49 +37,6 @@ LogicalResult KHRCooperativeMatrixLengthOp::verify() { // spirv.KHR.CooperativeMatrixLoad //===----------------------------------------------------------------------===// -ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser, - OperationState &result) { - std::array operandInfo = {}; - if (parser.parseOperand(operandInfo[0]) || parser.parseComma()) - return failure(); - if (parser.parseOperand(operandInfo[1]) || parser.parseComma()) - return failure(); - - CooperativeMatrixLayoutKHR layout; - if (parseEnumKeywordAttr( - layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { - return failure(); - } - - if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) - return failure(); - - Type ptrType; - Type elementType; - if (parser.parseColon() || parser.parseType(ptrType) || - parser.parseKeywordType("as", elementType)) { - return failure(); - } - result.addTypes(elementType); - - Type strideType = parser.getBuilder().getIntegerType(32); - if (parser.resolveOperands(operandInfo, {ptrType, strideType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - -void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getStride() << ", " - << getMatrixLayout(); - // Print optional memory operand attribute. - if (auto memOperand = getMemoryOperand()) - printer << " [\"" << memOperand << "\"]"; - printer << " : " << getPointer().getType() << " as " << getType(); -} - static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix) { auto pointerType = cast(pointer); diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir index aa6e072b03c5d..aad1e44bf8f7b 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -23,37 +23,46 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { // CHECK-LABEL: @cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32) "None" { - // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor : - // CHECK-SAME: !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor : - !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_memoperand spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr, %stride : i32) "None" { - // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] : - // CHECK-SAME: !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] : - !spirv.ptr as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, , : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32) "None" { - // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] : - // CHECK-SAME: !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] : - !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, , : + // CHECK-SAME: !spirv.ptr, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> spirv.Return } // CHECK-LABEL: @cooperative_matrix_load_function spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr, %stride : i32) "None" { - // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor : - // CHECK-SAME: !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc> - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor : - !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc> + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc> + spirv.Return +} + +// CHECK-LABEL: @cooperative_matrix_load_stride_i16 +spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr, %stride : i16) "None" { + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.ptr, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> spirv.Return } @@ -82,8 +91,8 @@ spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor : - !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> spirv.Return } @@ -92,16 +101,16 @@ spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{expected ','}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride : - !spirv.ptr as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> + !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> spirv.Return } // ----- spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr, %stride : i32) "None" { - // expected-error @+1 {{expected valid keyword}} + // expected-error @+1 {{expected '<'}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : - !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA> + !spirv.ptr, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA> spirv.Return } @@ -109,8 +118,8 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}} - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor : - !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : + !spirv.ptr, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> spirv.Return } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index ccf4240f8e560..8468f92600a44 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -929,9 +929,9 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef loc, if (auto *valueArg = llvm::dyn_cast_if_present(argument)) { if (valueArg->isVariableLength()) { if (i != e - 1) { - PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or " - "std::optional<...> arguments only if " - "it's the last argument"); + PrintFatalError( + loc, "SPIR-V ops can have Variadic<..> or " + "Optional<...> arguments only if it's the last argument"); } os << tabs << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);