Skip to content

Commit cda7d50

Browse files
committed
[mlir][spirv][gpu] Add conversion for load/store/mad coop matrix ops
This is plugged in as an alternative lowering path in the gpu to spirv dialect conversion. The remaining lowering patterns will be added in a future patch.
1 parent 2a07f0f commit cda7d50

File tree

7 files changed

+278
-11
lines changed

7 files changed

+278
-11
lines changed

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,21 @@ class MMAMatrixType;
3030
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
3131
RewritePatternSet &patterns);
3232

33+
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
34+
/// using the KHR Cooperative Matrix extension.
35+
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
36+
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
37+
3338
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
3439
/// using the NV Cooperative Matrix extension.
3540
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
3641
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
3742

43+
/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
44+
/// `type`.
45+
spirv::CooperativeMatrixType
46+
convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
47+
3848
/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
3949
/// `type`.
4050
spirv::CooperativeMatrixNVType

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
567567
let options = [
568568
Option<"use64bitIndex", "use-64bit-index",
569569
"bool", /*default=*/"false",
570-
"Use 64-bit integers to convert index types">
570+
"Use 64-bit integers to convert index types">,
571+
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
572+
"bool", /*default=*/"true",
573+
"Use the NV cooperative matrix extension instead of the KHR extension"
574+
" to lower GPU WMMA ops">,
571575
];
572576
}
573577

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
146146
let results = (outs
147147
SPIRV_AnyCooperativeMatrix:$result
148148
);
149+
150+
let builders = [
151+
OpBuilder<(ins "Type":$result, "Value":$pointer,
152+
"spirv::ConstantOp":$stride,
153+
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
154+
build($_builder, $_state, result, pointer, layout, stride,
155+
spirv::MemoryAccessAttr{});
156+
}]>
157+
];
149158
}
150159

151160
// -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
226235
);
227236

228237
let results = (outs);
238+
239+
let builders = [
240+
OpBuilder<(ins "Value":$pointer, "Value":$object,
241+
"spirv::ConstantOp":$stride,
242+
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
243+
build($_builder, $_state, pointer, object, layout, stride,
244+
spirv::MemoryAccessAttr{});
245+
}]>
246+
];
229247
}
230248

231249
// -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
332350
let results = (outs
333351
SPIRV_AnyCooperativeMatrix:$result
334352
);
353+
354+
let builders = [
355+
OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
356+
build($_builder, $_state, a, b, c,
357+
spirv::CooperativeMatrixOperandsKHRAttr{});
358+
}]>
359+
];
335360
}
336361

337362
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
8686
SPIRVConversionOptions options;
8787
options.use64bitIndex = this->use64bitIndex;
8888
SPIRVTypeConverter typeConverter(targetAttr, options);
89-
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
90-
return convertMMAToSPIRVCoopMatrixNVType(type);
89+
90+
typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
91+
gpu::MMAMatrixType type) -> Type {
92+
if (useNV)
93+
return convertMMAToSPIRVCoopMatrixNVType(type);
94+
95+
return convertMMAToSPIRVCoopMatrixType(type);
9196
});
97+
9298
RewritePatternSet patterns(context);
9399
populateGPUToSPIRVPatterns(typeConverter, patterns);
94-
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
95-
patterns);
100+
if (this->useCoopMatrixNV) {
101+
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
102+
patterns);
103+
} else {
104+
populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
105+
patterns);
106+
}
107+
96108
// TODO: Change SPIR-V conversion to be progressive and remove the following
97109
// patterns.
98110
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,28 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1919
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2020
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
2122
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2223
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
24+
#include "mlir/IR/BuiltinAttributes.h"
25+
#include "mlir/IR/BuiltinTypes.h"
2326
#include "mlir/IR/TypeUtilities.h"
27+
#include "llvm/ADT/StringSwitch.h"
2428

25-
namespace mlir::nv {
26-
namespace {
29+
#include <cassert>
2730

31+
namespace mlir {
2832
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
2933
/// when the elementwise op directly supports with cooperative matrix type.
3034
/// Returns false if cannot.
3135
///
3236
/// See SPV_NV_cooperative_matrix for supported elementwise ops.
3337
static bool createElementwiseOp(ConversionPatternRewriter &builder,
34-
gpu::SubgroupMmaElementwiseOp op,
35-
spirv::CooperativeMatrixNVType coopType,
38+
gpu::SubgroupMmaElementwiseOp op, Type coopType,
3639
ValueRange operands) {
40+
assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
41+
coopType)));
42+
3743
switch (op.getOpType()) {
3844
case gpu::MMAElementwiseOp::ADDF:
3945
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
7177
return false;
7278
}
7379

80+
//===----------------------------------------------------------------------===//
81+
// SPV_KHR_cooperative_matrix
82+
//===----------------------------------------------------------------------===//
83+
84+
namespace khr {
85+
namespace {
86+
87+
/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
88+
/// dialect.
89+
struct WmmaLoadOpToSPIRVLowering final
90+
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
91+
using OpConversionPattern::OpConversionPattern;
92+
93+
LogicalResult
94+
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
95+
ConversionPatternRewriter &rewriter) const override {
96+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
97+
Location loc = op->getLoc();
98+
99+
auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
100+
MemRefType memrefType = op.getSrcMemref().getType();
101+
Value bufferPtr =
102+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
103+
adaptor.getIndices(), loc, rewriter);
104+
105+
auto coopType =
106+
typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
107+
if (!coopType)
108+
return rewriter.notifyMatchFailure(op, "type conversion failed");
109+
110+
int64_t stride = op.getLeadDimension().getSExtValue();
111+
IntegerType i32Type = rewriter.getI32Type();
112+
auto strideValue = rewriter.create<spirv::ConstantOp>(
113+
loc, i32Type, IntegerAttr::get(i32Type, stride));
114+
115+
bool isColMajor = op.getTranspose().value_or(false);
116+
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
117+
: spirv::CooperativeMatrixLayoutKHR::RowMajor;
118+
119+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
120+
op, coopType, bufferPtr, strideValue, layout);
121+
return success();
122+
}
123+
};
124+
125+
/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
126+
/// dialect.
127+
struct WmmaStoreOpToSPIRVLowering final
128+
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
129+
using OpConversionPattern::OpConversionPattern;
130+
131+
LogicalResult
132+
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
133+
ConversionPatternRewriter &rewriter) const override {
134+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
135+
Location loc = op->getLoc();
136+
137+
auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
138+
Value bufferPtr =
139+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
140+
adaptor.getIndices(), loc, rewriter);
141+
142+
int64_t stride = op.getLeadDimension().getSExtValue();
143+
IntegerType i32Type = rewriter.getI32Type();
144+
auto strideValue = rewriter.create<spirv::ConstantOp>(
145+
loc, i32Type, IntegerAttr::get(i32Type, stride));
146+
147+
bool isColMajor = op.getTranspose().value_or(false);
148+
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
149+
: spirv::CooperativeMatrixLayoutKHR::RowMajor;
150+
151+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
152+
op, bufferPtr, adaptor.getSrc(), strideValue, layout);
153+
return success();
154+
}
155+
};
156+
157+
/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
158+
/// dialect.
159+
struct WmmaMmaOpToSPIRVLowering final
160+
: OpConversionPattern<gpu::SubgroupMmaComputeOp> {
161+
using OpConversionPattern::OpConversionPattern;
162+
163+
LogicalResult
164+
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
165+
OpAdaptor adaptor,
166+
ConversionPatternRewriter &rewriter) const override {
167+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
168+
subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
169+
adaptor.getOpC());
170+
return success();
171+
}
172+
};
173+
174+
} // namespace
175+
} // namespace khr
176+
177+
//===----------------------------------------------------------------------===//
178+
// SPV_NV_cooperative_matrix
179+
//===----------------------------------------------------------------------===//
180+
181+
namespace nv {
182+
namespace {
183+
74184
/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
75185
/// dialect.
76186
struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
247357
};
248358

249359
} // namespace
250-
} // namespace mlir::nv
360+
} // namespace nv
361+
} // namespace mlir
251362

252363
mlir::spirv::CooperativeMatrixNVType
253364
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
257368
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
258369
}
259370

371+
mlir::spirv::CooperativeMatrixType
372+
mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
373+
ArrayRef<int64_t> retTypeShape = type.getShape();
374+
Type elementType = type.getElementType();
375+
376+
auto use =
377+
llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
378+
.Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
379+
.Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
380+
.Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
381+
382+
return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
383+
retTypeShape[1],
384+
spirv::Scope::Subgroup, use);
385+
}
386+
387+
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
388+
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
389+
using namespace mlir;
390+
MLIRContext *context = patterns.getContext();
391+
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
392+
khr::WmmaStoreOpToSPIRVLowering>(converter, context);
393+
}
394+
260395
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
261396
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
262397
using namespace mlir;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
2+
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
3+
4+
module attributes {
5+
gpu.container_module,
6+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
7+
[Shader, CooperativeMatrixKHR, Float16],
8+
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
9+
#spirv.resource_limits<>>} {
10+
11+
gpu.module @kernels {
12+
// CHECK-LABEL: spirv.func @gpu_wmma_load_op
13+
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
14+
gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
15+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
16+
%i = arith.constant 16 : index
17+
%j = arith.constant 16 : index
18+
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
19+
// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
20+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
21+
%0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
22+
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
23+
24+
// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
25+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
26+
%1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
27+
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
28+
// CHECK: spirv.Return
29+
gpu.return
30+
}
31+
32+
// CHECK-LABEL: spirv.func @gpu_wmma_store_op
33+
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
34+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
35+
gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
36+
%arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
37+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
38+
%i = arith.constant 16 : index
39+
%j = arith.constant 16 : index
40+
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
41+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
42+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
43+
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
44+
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
45+
46+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
47+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
48+
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
49+
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
50+
// CHECK: spirv.Return
51+
gpu.return
52+
}
53+
54+
// CHECK-LABEL: spirv.func @gpu_wmma_mma_op
55+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
56+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
57+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
58+
gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
59+
%B: !gpu.mma_matrix<16x16xf16, "BOp">,
60+
%C: !gpu.mma_matrix<16x16xf16, "COp">,
61+
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
62+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
63+
// CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
64+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
65+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
66+
// CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
67+
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
68+
!gpu.mma_matrix<16x16xf16, "BOp">
69+
-> !gpu.mma_matrix<16x16xf16, "COp">
70+
71+
%i = arith.constant 0 : index
72+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
73+
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
74+
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
75+
// CHECK: spirv.Return
76+
gpu.return
77+
}
78+
79+
}
80+
}

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
2+
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
23

34
module attributes {
45
gpu.container_module,

0 commit comments

Comments
 (0)