[mlir][nvgpu] Improve nvgpu->nvvm transformation of warpgroup.mma Op (NFC) #67325

Merged · 2 commits · Oct 5, 2023

353 changes: 246 additions & 107 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -39,7 +39,7 @@ namespace mlir {

using namespace mlir;

/// Number of bits that needs to excluded when building matrix descriptor for
/// Number of bits that need to be excluded when building the matrix descriptor
/// wgmma operations.
constexpr int exclude4LSB = 4;
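// Note: the four dropped bits reflect the 16-byte granularity in which the
// wgmma shared-memory matrix descriptor encodes addresses and strides (see
// the PTX ISA description of the matrix descriptor format).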

@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
int &wgmmaShapeM, int &wgmmaShapeN,
int &wgmmaShapeK) const {
wgmmaShapeM = 64;
wgmmaShapeN = sizeN;
if (inputElemType.isTF32()) {
wgmmaShapeK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaShapeK = 16;
} else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
inputElemType.isInteger(16)) {
wgmmaShapeK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaShapeK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
/// This is a helper class to generate required NVVM Ops for warp-group level
/// matrix multiplication.
/// When the given GEMM shape is larger than the shape of
/// a single wgmma instruction in PTX, it generates multiple
/// NVVM::WgmmaMmaAsyncOp Ops, groups them, and executes them asynchronously.
/// The class also handles waiting for completion and iterates through
/// WarpgroupMatrixDescriptor to create descriptors for each instruction.
///
/// For example, this is the case when the GEMM shape is 128x128x128:
///
/// nvvm.wgmma.fence.aligned
///
/// nvvm.wgmma.mma.async descA, descB
/// iterate(descA, descB)
/// nvvm.wgmma.mma.async descA, descB
/// [6 more times]
///
/// nvvm.wgmma.group.sync.aligned
/// nvvm.wgmma.wait.group.sync [groupId]
///
class WarpgroupGemm {
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
const LLVMTypeConverter &typeConverter;

// Entire shape of the given Op
int64_t totalM, totalN, totalK;

// Shape of one wgmma instruction
int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

// Iteration counts for GEMM
int iterationM = 0, iterationN = 0, iterationK = 0;

/// This function determines the shape of the wgmma instruction, as defined
/// in the PTX programming guide.
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
wgmmaM = 64;
wgmmaN = sizeN;
if (inputElemType.isTF32()) {
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
} else if (inputElemType.isFloat8E4M3FN() ||
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
}
LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
<< ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
<< "]\n");
return success();
}

Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
Type resultStructType, Value inout,
Value descriptorA, Value descriptorB) const {
MLIRContext *ctx = b.getContext();
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
// todo: handle other input and output types
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);

LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
<< sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
<< sizeN << "] ---===\n");
/// Generates WGMMATypesAttr from MLIR Type
NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
auto getWgmmaType = [](Type elemType) {
if (elemType.isF32() || elemType.isTF32())
return NVVM::WGMMATypes::tf32;
if (elemType.isF16())
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
if (elemType.isFloat8E4M3FN())
return NVVM::WGMMATypes::e4m3;
if (elemType.isFloat8E5M2())
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
if (elemType.isInteger(8))
return NVVM::WGMMATypes::s8;
if (elemType.isUnsignedInteger(8))
return NVVM::WGMMATypes::u8;
llvm_unreachable("unsupported type");
};
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
}

int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
wgmmaShapeN, wgmmaShapeK))) {
return failure();
/// Generates layout attribute for the input matrix for wgmma instruction
NVVM::MMALayoutAttr
generateWgmmaLayout(std::optional<bool> transpose) const {
if (transpose.value_or(false))
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
}

Value descriptorA = adaptor.getDescriptorA();
Value descriptorB = adaptor.getDescriptorB();
/// Generates shape attribute for wgmma instruction
NVVM::MMAShapeAttr generateWgmmaShape() const {
return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
}

// Generate wgmma group
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
/// Generates scale attributes of output matrix for wgmma instruction
NVVM::WGMMAScaleOutAttr generateScaleOut() const {
return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
NVVM::WGMMAScaleOut::one);
}
/// Generates scale attributes of input matrix for wgmma instruction
NVVM::WGMMAScaleInAttr generateScaleIn() const {
return NVVM::WGMMAScaleInAttr::get(op->getContext(),
NVVM::WGMMAScaleIn::one);
}

auto makeAdd = [&](Value lhs, Value rhs) -> Value {
/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};

auto iterateDescA = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle column major
int byte = typeTensorA.getElementTypeBitWidth() / 8;
int tileShapeA = typeTensorA.getDimSize(1);
int incrementVal =
((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
/// Currently, it only handles row-major.
///
/// It moves the pointer like below for [128][64] size:
/// +2 +4 +6
/// ↓ ↓ ↓
/// descA ---> +--+--+--+--+
/// |->|->|->|->|
/// | | | | |
/// | | | | |
/// | | | | |
/// descA+512---> +-----------+
/// | | | | |
/// | | | | |
/// | | | | |
/// | | | | |
/// +-----------+
///
Value iterateDescriptorA(Value desc, int i, int j, int k) {
MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
Type elemA = matrixTypeA.getElementType();
int byte = elemA.getIntOrFloatBitWidth() / 8;
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
<< iterK << "] [wgmma descriptors] Descriptor A + "
LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
<< "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};
}
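
// Worked example (illustrative, not part of the lowering itself): for the
// f16 [128][64] matrix-A tile sketched above, byte = 2 and
// tileShapeA = totalK = 64, with wgmmaK = 16. Stepping k by one therefore
// adds (16 * 2) >> 4 = 2 to the descriptor (the "+2 +4 +6" arrows), and
// stepping i by one adds (64 * 64 * 2) >> 4 = 512 (the "descA+512" row).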

auto iterateDescB = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle row major
int byte = typeTensorB.getElementTypeBitWidth() / 8;
int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
/// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
/// Currently, it only handles column-major.
///
/// It moves the pointer like below for [128][64] size:
/// descB ---> +--+--+--+--+--+--+--+--+
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// +--+--+--+--+--+--+--+--+
///
Value iterateDescriptorB(Value desc, int i, int j, int k) {
MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
Type elemB = matrixTypeB.getElementType();
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};
}
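
// Worked example (illustrative): for the 128x128x128 f16 GEMM from the class
// comment, matrixTypeB.getDimSize(0) = 128, byte = 2 and wgmmaK = 16, so each
// k step adds (128 * 16 * 2) >> 4 = 256 to descriptor B. The i and j indices
// leave it unchanged: wgmmaN already covers the whole N dimension, and B does
// not depend on the M index.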

/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "] * "
<< " B[" << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");

Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

NVVM::MMAShapeAttr shape = generateWgmmaShape();
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());

auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);

Type resultStructType = typeConverter.convertType(matrixD.getType());

return b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}

b.create<NVVM::WgmmaFenceAlignedOp>();

SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
Value matrixC = adaptor.getMatrixC()[iterM];
Value matrixD = op.getMatrixD()[iterM];
Type structType = getTypeConverter()->convertType(matrixD.getType());
LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
<< (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
<< ":" << wgmmaShapeN << "] += \n");
for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
<< wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
<< ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
<< (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
/// Generates multiple wgmma instructions to complete the given GEMM shape
SmallVector<Value> generateWgmmaGroup() {
SmallVector<Value> wgmmaResults;

// Perform GEMM
for (int i = 0; i < iterationM; ++i) {
Value matrixC = adaptor.getMatrixC()[i];
Value matrixD = op.getMatrixD()[i];
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC, matrixD);
wgmmaResults.push_back(matrixC);
}
wgmmaResults.push_back(matrixC);

return wgmmaResults;
}
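
// Note on data flow: within each i, successive k iterations feed the previous
// WgmmaMmaAsyncOp result back in as the accumulator, so wgmmaResults ends up
// holding one fragmented accumulator per 64-row slice of D, i.e. iterationM
// entries in total.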

public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
<< "] += A[" << totalM << "][" << totalK << "] * B["
<< totalK << "][" << totalN << "] ---===\n");

// Find the shape for one wgmma instruction
findWgmmaShape(
totalM, totalN,
op.getDescriptorA().getType().getTensor().getElementType());

// Iteration counts needed to cover the given GEMM shape with the wgmma shape
iterationM = totalM / wgmmaM;
iterationN = totalN / wgmmaN;
iterationK = totalK / wgmmaK;
}
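
// Worked example (illustrative): for the 128x128x128 f16 GEMM mentioned in
// the class comment, findWgmmaShape picks wgmmaM = 64, wgmmaN = 128 and
// wgmmaK = 16, giving iterationM = 2, iterationN = 1 and iterationK = 8, so
// generateWgmmaGroup emits 2 * 1 * 8 = 16 NVVM::WgmmaMmaAsyncOp Ops inside a
// single fence/sync/wait region.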
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());

ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);
/// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
/// generates a fence Op (WgmmaFenceAlignedOp) before the instructions, and
/// afterwards performs group synchronization (WgmmaGroupSyncAlignedOp) and
/// waits for the group to complete (WgmmaWaitGroupSyncOp).
SmallVector<Value> generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
return wgmmaResults;
}
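
// For the 128x128x128 f16 example above this emits, in order (a sketch using
// the op spellings from the class comment):
//   nvvm.wgmma.fence.aligned
//   16 x nvvm.wgmma.mma.async   (8 per 64-row slice of D)
//   nvvm.wgmma.group.sync.aligned
//   nvvm.wgmma.wait.group.sync [groupId]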
};

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
// Step 1. Build a helper class
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());

// Step 2. Generate the wgmma instructions that perform the GEMM
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();

// Step 3. Replace fragmented result struct with the op results
rewriter.replaceOp(op, wgmmaResults);
return success();
}
};
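
For context (not part of this diff): like the other NVGPU-to-NVVM lowerings in
this file, the pattern is presumably added to the RewritePatternSet along the
lines of

patterns.add<NVGPUWarpgroupMmaOpLowering>(converter);

inside populateNVGPUToNVVMConversionPatterns(); the exact registration site is
outside the excerpt shown here.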