diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h index cd650345f1daa..d34549432161d 100644 --- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h @@ -31,16 +31,10 @@ void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); -/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, -/// using the NV Cooperative Matrix extension. -void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns( - SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); - -/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type -/// conversion to the type converter. Defaults to KHR cooperative matrix types. -/// When `useNVTypes` is `true`, uses the NV cooperative matrix types. +/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type +/// conversion to the type converter. void populateMMAToSPIRVCoopMatrixTypeConversion( - SPIRVTypeConverter &typeConverter, bool useNVTypes = false); + SPIRVTypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 6193aeb545bc6..71be8841ca7c0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -564,10 +564,6 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> { Option<"use64bitIndex", "use-64bit-index", "bool", /*default=*/"false", "Use 64-bit integers to convert index types">, - Option<"useCoopMatrixNV", "use-coop-matrix-nv", - "bool", /*default=*/"false", - "Use the NV cooperative matrix extension insted of the KHR extension" - " to lower GPU WMMA ops">, ]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index ee1fbba1e2844..6ec97e17c5dcc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -1253,12 +1253,6 @@ def SPIRV_C_RayTracingProvisionalKHR : I32EnumAttrCase<"RayTr Extension<[SPV_KHR_ray_tracing]> ]; } -def SPIRV_C_CooperativeMatrixNV : I32EnumAttrCase<"CooperativeMatrixNV", 5357> { - list implies = [SPIRV_C_Shader]; - list availability = [ - Extension<[SPV_NV_cooperative_matrix]> - ]; -} def SPIRV_C_FragmentShaderSampleInterlockEXT : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363> { list implies = [SPIRV_C_Shader]; list availability = [ @@ -1501,7 +1495,7 @@ def SPIRV_CapabilityAttr : SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray, SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV, SPIRV_C_RayTracingMotionBlurNV, SPIRV_C_PhysicalStorageBufferAddresses, - SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_CooperativeMatrixNV, + SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_FragmentShaderSampleInterlockEXT, SPIRV_C_FragmentShaderShadingRateInterlockEXT, SPIRV_C_ShaderSMBuiltinsNV, SPIRV_C_FragmentShaderPixelInterlockEXT, SPIRV_C_DemoteToHelperInvocation, @@ -4123,8 +4117,6 @@ class SignlessOrUnsignedIntOfWidths widths> : def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">; def SPIRV_IsCooperativeMatrixType : CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">; -def SPIRV_IsCooperativeMatrixNVType : - CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">; def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">; def SPIRV_IsJointMatrixType : CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">; @@ -4157,9 +4149,6 @@ def SPIRV_AnyArray : DialectType; -def SPIRV_AnyCooperativeMatrixNV : DialectType; def SPIRV_AnyImage : DialectType; def SPIRV_AnyJointMatrix : DialectType; def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>; def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV, - SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; + SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, - SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV, - SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage + SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, + SPIRV_AnySampledImage ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; @@ -4195,11 +4183,6 @@ class SPIRV_CoopMatrixOfType allowedTypes> : "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()", "Cooperative Matrix">; -class SPIRV_CoopMatrixNVOfType allowedTypes> : - ContainerType, SPIRV_IsCooperativeMatrixNVType, - "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()", - "Cooperative Matrix NV">; - class SPIRV_JointMatrixOfType allowedTypes> : ContainerType, SPIRV_IsJointMatrixType, "::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()", @@ -4213,12 +4196,11 @@ class SPIRV_ScalarOrVectorOf : class SPIRV_ScalarOrVectorOrCoopMatrixOf : AnyTypeOf<[type, SPIRV_VectorOf, - SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; + SPIRV_CoopMatrixOfType<[type]>]>; class SPIRV_MatrixOrCoopMatrixOf : AnyTypeOf<[SPIRV_AnyMatrix, - SPIRV_CoopMatrixOfType<[type]>, - SPIRV_CoopMatrixNVOfType<[type]>]>; + SPIRV_CoopMatrixOfType<[type]>]>; def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; @@ -4480,11 +4462,6 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrix def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>; def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>; def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>; -def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; -def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; -def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>; -def SPIRV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>; -def SPIRV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>; def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>; def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>; def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>; @@ -4585,9 +4562,6 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR, - SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV, - SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV, - SPIRV_OC_OpCooperativeMatrixLengthNV, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 29ad45bddd552..46732ba19afed 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul ]; } -//===----------------------------------------------------------------------===// -// SPV_NV_cooperative_matrix extension ops. -//===----------------------------------------------------------------------===// - -// ----- - -def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength", - [Pure]> { - let summary = "See extension SPV_NV_cooperative_matrix"; - - let description = [{ - Number of components of a cooperative matrix type accessible to each - invocation when treated as a composite. - - Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness. - - Type is a cooperative matrix type. - - #### Example: - - ``` - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - ``` - }]; - - let assemblyFormat = "attr-dict `:` $cooperative_matrix_type"; - - let availability = [ - MinVersion, - MaxVersion, - Extension<[SPV_NV_cooperative_matrix]>, - Capability<[SPIRV_C_CooperativeMatrixNV]> - ]; - - let arguments = (ins - TypeAttr:$cooperative_matrix_type - ); - - let results = (outs - SPIRV_Int32:$result - ); -} - -// ----- - -def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> { - let summary = "See extension SPV_NV_cooperative_matrix"; - - let description = [{ - Load a cooperative matrix through a pointer. - - Result Type is the type of the loaded object. It must be a cooperative - matrix type. - - Pointer is a pointer into an array. Its type must be an OpTypePointer whose - Type operand is a scalar or vector type. The storage class of Pointer must - be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is - supported) PhysicalStorageBufferEXT. - - Stride is the number of elements in the array in memory between the first - component of consecutive rows (or columns) in the result. It must be a - scalar integer type. - - ColumnMajor indicates whether the values loaded from memory are arranged in - column-major or row-major order. It must be a boolean constant instruction, - with false indicating row major and true indicating column major. - - Memory Access must be a Memory Access literal. If not present, it is the - same as specifying None. - - If ColumnMajor is false, then elements (row,*) of the result are taken in - order from contiguous locations starting at Pointer[row*Stride]. If - ColumnMajor is true, then elements (*,col) of the result are taken in order - from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride - decoration on Pointer is ignored. - - For a given dynamic instance of this instruction, all operands of this - instruction must be the same for all invocations in a given scope instance - (where the scope is the scope the cooperative matrix type was created with). - All invocations in a given scope instance must be active or all must be - inactive. - - ### Custom assembly form - - ``` {.ebnf} - cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad` - ssa-use `,` ssa-use `,` ssa-use - (`[` memory-access `]`)? ` : ` - pointer-type `as` - cooperative-matrix-type - ``` - - #### Example: - - ``` - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor - : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - ``` - }]; - - let availability = [ - MinVersion, - MaxVersion, - Extension<[SPV_NV_cooperative_matrix]>, - Capability<[SPIRV_C_CooperativeMatrixNV]> - ]; - - let arguments = (ins - SPIRV_AnyPtr:$pointer, - SPIRV_Integer:$stride, - SPIRV_Bool:$columnmajor, - OptionalAttr:$memory_access - ); - - let results = (outs - SPIRV_AnyCooperativeMatrixNV:$result - ); -} - -// ----- - -def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd", - [Pure, AllTypesMatch<["c", "result"]>]> { - let summary = "See extension SPV_NV_cooperative_matrix"; - - let description = [{ - Linear-algebraic matrix multiply of A by B and then component-wise add C. - The order of the operations is implementation-dependent. The internal - precision of floating-point operations is defined by the client API. - Integer operations are performed at the precision of the Result Type and are - exact unless there is overflow or underflow, in which case the result is - undefined. - - Result Type must be a cooperative matrix type with M rows and N columns. - - A is a cooperative matrix with M rows and K columns. - - B is a cooperative matrix with K rows and N columns. - - C is a cooperative matrix with M rows and N columns. - - The values of M, N, and K must be consistent across the result and operands. - This is referred to as an MxNxK matrix multiply. - - A, B, C, and Result Type must have the same scope, and this defines the - scope of the operation. A, B, C, and Result Type need not necessarily have - the same component type, this is defined by the client API. - - If the Component Type of any matrix operand is an integer type, then its - components are treated as signed if its Component Type has Signedness of 1 - and are treated as unsigned otherwise. - - For a given dynamic instance of this instruction, all invocations in a given - scope instance must be active or all must be inactive (where the scope is - the scope of the operation). - - #### Example: - - ``` - %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, : - !spirv.NV.coopmatrix<8x16xi32, Subgroup> - ``` - }]; - - let assemblyFormat = [{ - operands attr-dict `:` type($a) `,` type($b) `->` type($c) - }]; - - let availability = [ - MinVersion, - MaxVersion, - Extension<[SPV_NV_cooperative_matrix]>, - Capability<[SPIRV_C_CooperativeMatrixNV]> - ]; - - let arguments = (ins - SPIRV_AnyCooperativeMatrixNV:$a, - SPIRV_AnyCooperativeMatrixNV:$b, - SPIRV_AnyCooperativeMatrixNV:$c - ); - - let results = (outs - SPIRV_AnyCooperativeMatrixNV:$result - ); -} - -// ----- - -def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> { - let summary = "See extension SPV_NV_cooperative_matrix"; - - let description = [{ - Store a cooperative matrix through a pointer. - - Pointer is a pointer into an array. Its type must be an OpTypePointer whose - Type operand is a scalar or vector type. The storage class of Pointer must - be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is - supported) PhysicalStorageBufferEXT. - - Object is the object to store. Its type must be an - OpTypeCooperativeMatrixNV. - - Stride is the number of elements in the array in memory between the first - component of consecutive rows (or columns) in the result. It must be a - scalar integer type. - - ColumnMajor indicates whether the values stored to memory are arranged in - column-major or row-major order. It must be a boolean constant instruction, - with false indicating row major and true indicating column major. - - Memory Access must be a Memory Access literal. If not present, it is the - same as specifying None. - - ``` {.ebnf} - coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore ` - ssa-use `, ` ssa-use `, ` - ssa-use `, ` ssa-use `, ` - (`[` memory-access `]`)? `:` - pointer-type `,` coop-matrix-type - ``` - - #### Example: - - ``` - spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 : - !spirv.ptr, !spirv.NV.coopmatrix<16x8xi32, Workgroup> - ``` - }]; - - let availability = [ - MinVersion, - MaxVersion, - Extension<[SPV_NV_cooperative_matrix]>, - Capability<[SPIRV_C_CooperativeMatrixNV]> - ]; - - let arguments = (ins - SPIRV_AnyPtr:$pointer, - SPIRV_AnyCooperativeMatrixNV:$object, - SPIRV_Integer:$stride, - SPIRV_Bool:$columnmajor, - OptionalAttr:$memory_access - ); - - let results = (outs); -} - // ----- #endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index d946d936d4e6c..55f0c787b4440 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -29,7 +29,6 @@ namespace spirv { namespace detail { struct ArrayTypeStorage; struct CooperativeMatrixTypeStorage; -struct CooperativeMatrixNVTypeStorage; struct ImageTypeStorage; struct JointMatrixTypeStorage; struct MatrixTypeStorage; @@ -421,32 +420,6 @@ class CooperativeMatrixType std::optional storage = std::nullopt); }; -// SPIR-V NV cooperative matrix type -class CooperativeMatrixNVType - : public Type::TypeBase { -public: - using Base::Base; - - static constexpr StringLiteral name = "spirv.NV.coopmatrix"; - - static CooperativeMatrixNVType get(Type elementType, Scope scope, - unsigned rows, unsigned columns); - Type getElementType() const; - - /// Returns the scope of the matrix. - Scope getScope() const; - /// Returns the number of rows of the matrix. - unsigned getRows() const; - /// Returns the number of columns of the matrix. - unsigned getColumns() const; - - void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage = std::nullopt); - void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage = std::nullopt); -}; - // SPIR-V joint matrix type class JointMatrixINTELType : public Type::TypeBaseuse64bitIndex; SPIRVTypeConverter typeConverter(targetAttr, options); - populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter, - this->useCoopMatrixNV); + populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter); RewritePatternSet patterns(context); populateGPUToSPIRVPatterns(typeConverter, patterns); - if (this->useCoopMatrixNV) { - populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter, - patterns); - } else { - populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter, - patterns); - } + populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter, + patterns); // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index 4a4281aaaf0db..92cc0eadb9784 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -32,19 +32,18 @@ namespace mlir { //===----------------------------------------------------------------------===// -// Patterns and helpers used by both the KHR and the NV lowering paths. +// Patterns and helpers. //===----------------------------------------------------------------------===// /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op /// when the elementwise op directly supports with cooperative matrix type. /// Returns false if cannot. /// -/// See SPV_NV_cooperative_matrix for supported elementwise ops. +/// See SPV_KHR_cooperative_matrix for supported elementwise ops. static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands) { - assert((isa( - coopType))); + assert((isa(coopType))); switch (op.getOpType()) { case gpu::MMAElementwiseOp::ADDF: @@ -89,8 +88,7 @@ bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { llvm::map_range(operands, [](Value v) { return v.getType(); }))) return false; - return isa( - operands.front().getType()); + return isa(operands.front().getType()); } namespace { @@ -292,104 +290,6 @@ struct WmmaMmaOpToSPIRVLowering final } // namespace } // namespace khr - -//===----------------------------------------------------------------------===// -// SPV_NV_cooperative_matrix -//===----------------------------------------------------------------------===// - -namespace nv { -namespace { - -/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV -/// dialect. -struct WmmaLoadOpToSPIRVLowering final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = subgroupMmaLoadMatrixOp->getLoc(); - auto &typeConverter = *getTypeConverter(); - - gpu::MMAMatrixType retType = - cast(subgroupMmaLoadMatrixOp.getRes().getType()); - auto memrefType = - cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()); - Value bufferPtr = - spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(), - adaptor.getIndices(), loc, rewriter); - auto coopType = - typeConverter.convertType(retType); - if (!coopType) - return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp, - "type conversion failed"); - - int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue(); - auto i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( - loc, i32Type, IntegerAttr::get(i32Type, stride)); - bool isColMajor = static_cast(subgroupMmaLoadMatrixOp.getTranspose()); - auto columnMajor = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor)); - rewriter.replaceOpWithNewOp( - subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor, - spirv::MemoryAccessAttr()); - return success(); - } -}; - -/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV -/// dialect. -struct WmmaStoreOpToSPIRVLowering final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = subgroupMmaStoreMatrixOp->getLoc(); - auto memrefType = - cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()); - Value bufferPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, - adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter); - int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue(); - auto i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create( - loc, i32Type, IntegerAttr::get(i32Type, stride)); - bool useColMajor = - static_cast(subgroupMmaStoreMatrixOp.getTranspose()); - auto columnMajor = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor)); - rewriter.replaceOpWithNewOp( - subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue, - columnMajor, spirv::MemoryAccessAttr()); - return success(); - } -}; - -/// Converts GPU MMA Compute to -/// NVCooperativeMatrixMulAdd op in the SPIRV dialect. -struct WmmaMmaOpToSPIRVLowering final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(), - adaptor.getOpB(), adaptor.getOpC()); - return success(); - } -}; - -} // namespace -} // namespace nv } // namespace mlir void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( @@ -404,31 +304,8 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( /*benefit=*/2); } -void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns( - SPIRVTypeConverter &converter, RewritePatternSet &patterns) { - using namespace mlir; - MLIRContext *context = patterns.getContext(); - patterns.add(converter, context); - // Give the following patterns higher benefit to prevail over the default one. - patterns.add(converter, context, - /*benefit=*/2); -} - void mlir::populateMMAToSPIRVCoopMatrixTypeConversion( - mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) { - if (useNVTypes) { - typeConverter.addConversion([](gpu::MMAMatrixType type) { - ArrayRef retTypeShape = type.getShape(); - Type elementType = type.getElementType(); - return spirv::CooperativeMatrixNVType::get( - elementType, spirv::Scope::Subgroup, retTypeShape[0], - retTypeShape[1]); - }); - return; - } - + mlir::SPIRVTypeConverter &typeConverter) { typeConverter.addConversion([](gpu::MMAMatrixType type) { ArrayRef retTypeShape = type.getShape(); Type elementType = type.getElementType(); diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index f24da2ca5c3f2..52b4380ed27f7 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -37,7 +37,7 @@ static LogicalResult verifyCastOp(Operation *op, auto [operandElemTy, resultElemTy] = TypeSwitch(operandType) .Case( + spirv::JointMatrixINTELType>( [resultType](auto concreteOperandTy) -> TypePair { if (auto concreteResultTy = dyn_cast(resultType)) { diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index c8b274ceec3e5..d532d466334a5 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -136,156 +136,4 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixLength -//===----------------------------------------------------------------------===// - -LogicalResult NVCooperativeMatrixLengthOp::verify() { - if (!isa(getCooperativeMatrixType())) { - return emitOpError( - "type attribute must be a '!spirv.NV.coopmatrix' type, found ") - << getCooperativeMatrixType() << " instead"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// - -ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operandInfo; - Type strideType = parser.getBuilder().getIntegerType(32); - Type columnMajorType = parser.getBuilder().getIntegerType(1); - Type ptrType; - Type elementType; - if (parser.parseOperandList(operandInfo, 3) || - parseMemoryAccessAttributes(parser, result) || parser.parseColon() || - parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { - return failure(); - } - if (parser.resolveOperands(operandInfo, - {ptrType, strideType, columnMajorType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - result.addTypes(elementType); - return success(); -} - -void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getStride() << ", " - << getColumnmajor(); - // Print optional memory access attribute. - if (auto memAccess = getMemoryAccess()) - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << getPointer().getType() << " as " << getType(); -} - -static LogicalResult -verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) { - Type pointeeType = llvm::cast(pointer).getPointeeType(); - if (!llvm::isa(pointeeType) && - !llvm::isa(pointeeType)) - return op->emitError( - "Pointer must point to a scalar or vector type but provided ") - << pointeeType; - StorageClass storage = llvm::cast(pointer).getStorageClass(); - if (storage != StorageClass::Workgroup && - storage != StorageClass::StorageBuffer && - storage != StorageClass::PhysicalStorageBuffer) - return op->emitError( - "Pointer storage class must be Workgroup, StorageBuffer or " - "PhysicalStorageBufferEXT but provided ") - << stringifyStorageClass(storage); - return success(); -} - -LogicalResult NVCooperativeMatrixLoadOp::verify() { - return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), - getResult().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixStore -//===----------------------------------------------------------------------===// - -ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operandInfo; - Type strideType = parser.getBuilder().getIntegerType(32); - Type columnMajorType = parser.getBuilder().getIntegerType(1); - Type ptrType; - Type elementType; - if (parser.parseOperandList(operandInfo, 4) || - parseMemoryAccessAttributes(parser, result) || parser.parseColon() || - parser.parseType(ptrType) || parser.parseComma() || - parser.parseType(elementType)) { - return failure(); - } - if (parser.resolveOperands( - operandInfo, {ptrType, elementType, strideType, columnMajorType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - -void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getObject() << ", " << getStride() - << ", " << getColumnmajor(); - // Print optional memory access attribute. - if (auto memAccess = getMemoryAccess()) - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); -} - -LogicalResult NVCooperativeMatrixStoreOp::verify() { - return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), - getObject().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixMulAdd -//===----------------------------------------------------------------------===// - -static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) { - if (op.getC().getType() != op.getResult().getType()) - return op.emitOpError("result and third operand must have the same type"); - auto typeA = llvm::cast(op.getA().getType()); - auto typeB = llvm::cast(op.getB().getType()); - auto typeC = llvm::cast(op.getC().getType()); - auto typeR = llvm::cast(op.getResult().getType()); - if (typeA.getRows() != typeR.getRows() || - typeA.getColumns() != typeB.getRows() || - typeB.getColumns() != typeR.getColumns()) - return op.emitOpError("matrix size must match"); - if (typeR.getScope() != typeA.getScope() || - typeR.getScope() != typeB.getScope() || - typeR.getScope() != typeC.getScope()) - return op.emitOpError("matrix scope must match"); - auto elementTypeA = typeA.getElementType(); - auto elementTypeB = typeB.getElementType(); - if (isa(elementTypeA) && isa(elementTypeB)) { - if (llvm::cast(elementTypeA).getWidth() != - llvm::cast(elementTypeB).getWidth()) - return op.emitOpError( - "matrix A and B integer element types must be the same bit width"); - } else if (elementTypeA != elementTypeB) { - return op.emitOpError( - "matrix A and B non-integer element types must match"); - } - if (typeR.getElementType() != typeC.getElementType()) - return op.emitOpError("matrix accumulator element type must match"); - return success(); -} - -LogicalResult NVCooperativeMatrixMulAddOp::verify() { - return verifyCoopMatrixMulAddNV(*this); -} - } // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 8a68decc5878c..9d4d1aec36709 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -360,37 +360,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); } -// nv-cooperative-matrix-type ::= -// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>` -static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect, - DialectAsmParser &parser) { - if (parser.parseLess()) - return Type(); - - SmallVector dims; - SMLoc countLoc = parser.getCurrentLocation(); - if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) - return Type(); - - if (dims.size() != 2) { - parser.emitError(countLoc, "expected rows and columns size"); - return Type(); - } - - auto elementTy = parseAndVerifyType(dialect, parser); - if (!elementTy) - return Type(); - - Scope scope; - if (parser.parseComma() || - spirv::parseEnumKeywordAttr(scope, parser, "scope ")) - return Type(); - - if (parser.parseGreater()) - return Type(); - return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); -} - // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x` // element-type // `,` layout `,` scope`>` @@ -810,8 +779,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const { return parseArrayType(*this, parser); if (keyword == "coopmatrix") return parseCooperativeMatrixType(*this, parser); - if (keyword == "NV.coopmatrix") - return parseCooperativeMatrixNVType(*this, parser); if (keyword == "jointmatrix") return parseJointMatrixType(*this, parser); if (keyword == "image") @@ -917,12 +884,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { << type.getUse() << ">"; } -static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { - os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; - os << type.getElementType() << ", " << stringifyScope(type.getScope()); - os << ">"; -} - static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; os << type.getElementType() << ", " @@ -937,10 +898,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) { void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) - .Case( - [&](auto type) { print(type, os); }) + .Case([&](auto type) { print(type, os); }) .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 2a1d083308282..dc558b878b3b7 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -374,8 +374,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { auto coopElementType = llvm::TypeSwitch(getType()) - .Case( + .Case( [](auto coopType) { return coopType.getElementType(); }) .Default([](Type) { return nullptr; }); @@ -1611,8 +1610,7 @@ LogicalResult spirv::VectorShuffleOp::verify() { LogicalResult spirv::MatrixTimesScalarOp::verify() { Type elementType = llvm::TypeSwitch(getMatrix().getType()) - .Case( + .Case( [](auto matrixType) { return matrixType.getElementType(); }) .Default([](Type) { return nullptr; }); @@ -1751,7 +1749,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() { return emitError("result type must be a composite type, but provided ") << getType(); - if (llvm::isa(cType)) + if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index f1bac6490837b..3f25696aa5eb6 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -95,9 +95,8 @@ bool CompositeType::classof(Type type) { if (auto vectorType = llvm::dyn_cast(type)) return isValid(vectorType); return llvm::isa(type); + spirv::JointMatrixINTELType, spirv::MatrixType, + spirv::RuntimeArrayType, spirv::StructType>(type); } bool CompositeType::isValid(VectorType type) { @@ -108,8 +107,8 @@ bool CompositeType::isValid(VectorType type) { Type CompositeType::getElementType(unsigned index) const { return TypeSwitch(*this) - .Case( + .Case( [](auto type) { return type.getElementType(); }) .Case([](MatrixType type) { return type.getColumnType(); }) .Case( @@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const { return structType.getNumElements(); if (auto vectorType = llvm::dyn_cast(*this)) return vectorType.getNumElements(); - if (llvm::isa(*this)) { + if (llvm::isa(*this)) { llvm_unreachable( "invalid to query number of elements of spirv Cooperative Matrix type"); } @@ -143,16 +142,16 @@ unsigned CompositeType::getNumElements() const { } bool CompositeType::hasCompileTimeKnownNumElements() const { - return !llvm::isa(*this); + return !llvm::isa(*this); } void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getExtensions(extensions, storage); }) .Case([&](VectorType type) { return llvm::cast(type.getElementType()) @@ -165,8 +164,8 @@ void CompositeType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage) { TypeSwitch(*this) - .Case( + .Case( [&](auto type) { type.getCapabilities(capabilities, storage); }) .Case([&](VectorType type) { auto vecSize = getNumElements(); @@ -267,70 +266,6 @@ void CooperativeMatrixType::getCapabilities( capabilities.push_back(caps); } -//===----------------------------------------------------------------------===// -// CooperativeMatrixNVType -//===----------------------------------------------------------------------===// - -struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage { - using KeyTy = std::tuple; - - static CooperativeMatrixNVTypeStorage * - construct(TypeStorageAllocator &allocator, const KeyTy &key) { - return new (allocator.allocate()) - CooperativeMatrixNVTypeStorage(key); - } - - bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, scope, rows, columns); - } - - CooperativeMatrixNVTypeStorage(const KeyTy &key) - : elementType(std::get<0>(key)), rows(std::get<2>(key)), - columns(std::get<3>(key)), scope(std::get<1>(key)) {} - - Type elementType; - unsigned rows; - unsigned columns; - Scope scope; -}; - -CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType, - Scope scope, unsigned rows, - unsigned columns) { - return Base::get(elementType.getContext(), elementType, scope, rows, columns); -} - -Type CooperativeMatrixNVType::getElementType() const { - return getImpl()->elementType; -} - -Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; } - -unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; } - -unsigned CooperativeMatrixNVType::getColumns() const { - return getImpl()->columns; -} - -void CooperativeMatrixNVType::getExtensions( - SPIRVType::ExtensionArrayRefVector &extensions, - std::optional storage) { - llvm::cast(getElementType()).getExtensions(extensions, storage); - static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix}; - ArrayRef ref(exts, std::size(exts)); - extensions.push_back(ref); -} - -void CooperativeMatrixNVType::getCapabilities( - SPIRVType::CapabilityArrayRefVector &capabilities, - std::optional storage) { - llvm::cast(getElementType()) - .getCapabilities(capabilities, storage); - static const Capability caps[] = {Capability::CooperativeMatrixNV}; - ArrayRef ref(caps, std::size(caps)); - capabilities.push_back(ref); -} - //===----------------------------------------------------------------------===// // JointMatrixType //===----------------------------------------------------------------------===// @@ -1312,7 +1247,7 @@ void MatrixType::getCapabilities( //===----------------------------------------------------------------------===// void SPIRVDialect::registerTypes() { - addTypes(); + addTypes(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index 954aaa98c3299..a678124bf4832 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -165,7 +165,6 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: case spirv::Opcode::OpTypeCooperativeMatrixKHR: - case spirv::Opcode::OpTypeCooperativeMatrixNV: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: return processTypeForwardPointer(operands); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 89e2e7ad52fa7..948dcfb4885b3 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -778,8 +778,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return processArrayType(operands); case spirv::Opcode::OpTypeCooperativeMatrixKHR: return processCooperativeMatrixTypeKHR(operands); - case spirv::Opcode::OpTypeCooperativeMatrixNV: - return processCooperativeMatrixTypeNV(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeJointMatrixINTEL: @@ -955,37 +953,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( return success(); } -LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV( - ArrayRef operands) { - if (operands.size() != 5) { - return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element " - "type and row x column parameters"); - } - - Type elementTy = getType(operands[1]); - if (!elementTy) { - return emitError(unknownLoc, - "OpTypeCooperativeMatrixNV references undefined ") - << operands[1]; - } - - std::optional scope = - spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); - if (!scope) { - return emitError( - unknownLoc, - "OpTypeCooperativeMatrixNV references undefined scope ") - << operands[2]; - } - - unsigned rows = getConstantInt(operands[3]).getInt(); - unsigned columns = getConstantInt(operands[4]).getInt(); - - typeMap[operands[0]] = - spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns); - return success(); -} - LogicalResult spirv::Deserializer::processJointMatrixType(ArrayRef operands) { if (operands.size() != 6) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 9e9a16456cc10..08395dd4cf522 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -632,26 +632,6 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto cooperativeMatrixType = - dyn_cast(type)) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), - elementTypeID, serializationCtx))) { - return failure(); - } - typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; - auto getConstantOp = [&](uint32_t id) { - auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); - return prepareConstantInt(loc, attr); - }; - llvm::append_values( - operands, elementTypeID, - getConstantOp(static_cast(cooperativeMatrixType.getScope())), - getConstantOp(cooperativeMatrixType.getRows()), - getConstantOp(cooperativeMatrixType.getColumns())); - return success(); - } - if (auto jointMatrixType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, jointMatrixType.getElementType(), diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir index f129cc8ce84ec..477f344b1ae5f 100644 --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \ +// RUN: mlir-opt --convert-gpu-to-spirv --cse \ // RUN: --split-input-file --verify-diagnostics %s | FileCheck %s module attributes { diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir deleted file mode 100644 index ec7da92704c07..0000000000000 --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir +++ /dev/null @@ -1,194 +0,0 @@ -// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \ -// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_load_op - // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> - gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class>) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %i = arith.constant 16 : index - %j = arith.constant 16 : index - // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose - // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} - // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi - gpu.func @gpu_wmma_load_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class>) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %i = arith.constant 16 : index - %j = arith.constant 16 : index - // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr as !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_store_op - // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %i = arith.constant 16 : index - %j = arith.constant 16 : index - // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.NV.coopmatrix<16x16xf16, Subgroup> - gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose - // CHECK-SAME: {{%.*}}: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} - // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) - // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi - gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %i = arith.constant 16 : index - %j = arith.constant 16 : index - // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr, !spirv.NV.coopmatrix<16x16xf16, Subgroup> - gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_mma_op - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_constant_op - gpu.func @gpu_wmma_constant_op() kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: {{%.*}} = spirv.Constant - %cst = arith.constant 1.0 : f16 - // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup> - %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar - // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: %[[S:.+]]: f16 - gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: spirv.Return - gpu.return - } - } -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { - gpu.module @kernels { - // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar - // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup> - // CHECK-SAME: %[[S:.+]]: f16 - gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp"> - // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup> - %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> - gpu.return - } - } -} diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir index 4f4a72da7c050..aaee2ccd3cb8c 100644 --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -146,14 +146,6 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou // ----- -func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spirv.ConvertSToF //===----------------------------------------------------------------------===// @@ -238,14 +230,6 @@ func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, M // ----- -func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) { - // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> - %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> - spirv.Return -} - -// ----- - func.func @f_convert_vector(%arg0 : f32) -> f32 { // expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}} %0 = spirv.FConvert %arg0 : f32 to f32 @@ -254,14 +238,6 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 { // ----- -func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) { - // expected-error @+1 {{incompatible operand and result types}} - %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup> - spirv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spirv.SConvert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir index b10677f0f5f99..3fc8dfb2767d1 100644 --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -32,13 +32,6 @@ func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix< return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> } -// CHECK-LABEL: func @composite_construct_coopmatrix_nv -func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { - // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> -} - // ----- func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { @@ -75,22 +68,6 @@ func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32 // ----- -func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { - // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} - %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> -} - -// ----- - -func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { - // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}} - %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> - return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> -} - -// ----- - func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> { // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}} %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32> @@ -143,14 +120,6 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 { // ----- -func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 { - // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - return %0 : f32 -} - -// ----- - func.func @composite_extract_no_ssa_operand() -> () { // expected-error @+1 {{expected SSA operand}} %0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>> @@ -271,14 +240,6 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3 // ----- -func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> { - // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup> - return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup> -} - -// ----- - func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> { // expected-error @+1 {{expected at least one index}} %0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32> diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 445ab8a48d3ce..d3e1dbc229ef9 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -13,14 +13,6 @@ spirv.func @cooperative_matrix_length() -> i32 "None" { // ----- -spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { - // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}} - %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 -} - -// ----- - // CHECK-LABEL: @cooperative_matrix_load spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32) "None" { // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, : @@ -118,24 +110,6 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr, %stride : i32) "None" { - // expected-error @+1 {{expected '<'}} - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : - !spirv.ptr, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" { - // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}} - %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, : - !spirv.ptr, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// ----- - spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir index f52666af280e4..372fcc6e514b9 100644 --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce { } // CHECK-LABEL: @matrix_times_scalar_2 - spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup> + spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16 + spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> } // CHECK-LABEL: @matrix_transpose_1 diff --git a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir deleted file mode 100644 index 43cbf61b60ef0..0000000000000 --- a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir +++ /dev/null @@ -1,177 +0,0 @@ -// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s - -//===----------------------------------------------------------------------===// -// NV.CooperativeMatrix -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @cooperative_matrix_load -spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_load_memaccess -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type -spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_length -spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 -} - -// CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return -} - -// CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return -} - -// ----- - -// CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { - %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 - spirv.ReturnValue %1 : !spirv.ptr -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{matrix A and B non-integer element types must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{matrix A and B integer element types must be the same bit width}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}} - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}} - %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> - spirv.ReturnValue %0 : i32 -} diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir index 722e4434aeaf9..6f6ce1202d170 100644 --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -797,7 +797,7 @@ spirv.module Logical GLSL450 { } //===----------------------------------------------------------------------===// -// spirv.SpecConstantComposite (spirv.NV.coopmatrix) +// spirv.SpecConstantComposite (spirv.KHR.coopmatrix) //===----------------------------------------------------------------------===// // ----- @@ -805,7 +805,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.SpecConstant @sc1 = 1.5 : f32 // expected-error @+1 {{unsupported composite type}} - spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device> + spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA> } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index e10a6fc77e856..05ab91b6db6bd 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -479,25 +479,6 @@ func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup // ----- -//===----------------------------------------------------------------------===// -// NV.CooperativeMatrix -//===----------------------------------------------------------------------===// - -// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> () - -// ----- - -// expected-error @+1 {{expected ','}} -func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> () - -// ----- - -// expected-error @+1 {{expected rows and columns size}} -func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> () - -// ----- - //===----------------------------------------------------------------------===// // Matrix //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir index af8f41a30d24f..b52c3f4aa2f11 100644 --- a/mlir/test/Target/SPIRV/matrix.mlir +++ b/mlir/test/Target/SPIRV/matrix.mlir @@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce { } // CHECK-LABEL: @matrix_times_scalar_3 - spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16 - spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup> + spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16 + spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> } // CHECK-LABEL: @matrix_transpose_1 diff --git a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir deleted file mode 100644 index 2eec99f72691c..0000000000000 --- a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir +++ /dev/null @@ -1,102 +0,0 @@ -// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s - -spirv.module Logical GLSL450 requires #spirv.vce { - // CHECK-LABEL: @cooperative_matrix_load - spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_load_memaccess - spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_store - spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_store_memaccess - spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_length - spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 - } - - // CHECK-LABEL: @cooperative_matrix_muladd - spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_add - spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_sub - spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_sdiv - spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_udiv - spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_fadd - spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_fsub - spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_fdiv - spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - spirv.Return - } - - // CHECK-LABEL: @cooperative_matrix_access_chain - spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { - %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 - spirv.ReturnValue %1 : !spirv.ptr - } -}