[mlir][nvgpu] Improve nvgpu->nvvm transformation of warpgroup.mma Op (NFC) #67325


Merged
merged 2 commits into llvm:main from the nvgpu-mma-better-code branch on Oct 5, 2023

Conversation

@grypp (Member) commented Sep 25, 2023

This PR introduces substantial improvements to the readability and maintainability of the nvgpu.warpgroup.mma Op transformation from nvgpu->nvvm. This transformation plays a crucial role in GEMM and manages complex operations such as generating multiple wgmma ops and iterating their descriptors. The prior code lacked clarity, but this PR addresses that issue effectively.

This PR does the following:

  • Introduces a helper class: the WarpgroupGemm class encapsulates the necessary functionality, making the code cleaner and more understandable. Its net effect on the pattern body is sketched below.

  • Adds detailed documentation: each function within the helper class is thoroughly documented to provide clear insight into its purpose and functionality.
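
At a glance, the refactoring reduces matchAndRewrite to a three-step body once WarpgroupGemm owns the details. The following is condensed from the diff further down (surrounding class and includes omitted):

```cpp
// Condensed from the patch: the pattern delegates all wgmma generation to
// the new WarpgroupGemm helper and only replaces the op with its results.
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Step 1. Build the helper class.
  WarpgroupGemm warpgroupGemm(op, rewriter, adaptor, *getTypeConverter());
  // Step 2. Generate wgmma ops covering the entire GEMM shape.
  SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
  // Step 3. Replace the op with the fragmented results.
  rewriter.replaceOp(op, wgmmaResults);
  return success();
}
```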

@llvmbot (Member) commented Sep 25, 2023

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Full diff: https://github.com/llvm/llvm-project/pull/67325.diff

1 File Affected:

  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+238-111)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..3bbee8934a1d4ae 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -38,7 +38,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that need to be excluded when building a matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1168,140 +1168,267 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
+  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
+  /// instruction, it can generate multiple wgmma instructions, group and execute
+  /// them asynchronously. The class also handles waiting for instruction
+  /// completion and iterates through GenerateGmmaDescriptor to create
+  /// descriptors for each instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// The function returns the shape of wgmma instruction that is defined in
+    /// PTX programming guide.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(MLIRContext *ctx,
-                            ConversionPatternRewriter &rewriter, Location loc,
-                            int m, int n, int k, Type resultStructType,
-                            Value inout, Value descriptorA,
-                            Value descriptorB) const {
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
-        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
-
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
+    }
 
-    //  Generate wgmma group
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    auto loc = op->getLoc();
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB     ---> +--+--+--+--+--+--+--+--+
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                |↓ |  |  |  |  |  |  |  |
+    ///                +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
-    };
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }
 
-    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
-                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts to complete the given shape with the wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// generates a fence Op (WgmmaFenceAlignedOp) before the instructions,
+    /// then commits the wgmma group (WgmmaGroupSyncAlignedOp) and waits for
+    /// its completion (WgmmaWaitGroupSyncOp) after the instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Get the entire GEMM Shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
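
To make the tiling arithmetic concrete, here is a standalone sketch of the math performed by findWgmmaShape, the WarpgroupGemm constructor, and iterateDescriptorA. It is plain C++, independent of MLIR, with assumed illustrative shapes: a 128x128x64 f16 GEMM, matching the [128][64] diagram in the comments above. The real code derives everything from the Op's descriptor types.

```cpp
#include <cstdio>

int main() {
  // Entire GEMM shape (assumed): D[128][128] += A[128][64] * B[64][128], f16.
  const int totalM = 128, totalN = 128, totalK = 64;

  // findWgmmaShape: m is fixed at 64, n follows the GEMM, and k depends on
  // the element type (16 for f16/bf16).
  const int wgmmaM = 64, wgmmaN = totalN, wgmmaK = 16;

  // Iteration counts computed in the WarpgroupGemm constructor. N is covered
  // by a single instruction here because wgmmaN == totalN.
  const int iterationM = totalM / wgmmaM; // 2
  const int iterationK = totalK / wgmmaK; // 4
  std::printf("wgmma ops generated: %d\n", iterationM * iterationK); // 8

  // iterateDescriptorA: byte offset into row-major A for tile (i, k). The
  // 4 least-significant bits are dropped (exclude4LSB) because wgmma matrix
  // descriptors encode addresses without them.
  const int byte = 2;            // f16
  const int tileShapeA = totalK; // second dimension of A
  for (int i = 0; i < iterationM; ++i)
    for (int k = 0; k < iterationK; ++k) {
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      std::printf("A tile (i=%d, k=%d): descriptor += %d\n", i, k,
                  incrementVal >> 4);
    }
  return 0;
}
```

For i = 0 the per-k increments come out to 2, 4, 6, and for i = 1 the base offset is 512, which matches the "+2 +4 +6" and "descA+512" annotations in the iterateDescriptorA diagram above.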

grypp added 2 commits October 5, 2023 10:03
…mma` Op (NFC)

This PR introduces substantial improvements to the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation plays a crucial role in GEMM and manages complex operations such as generating multiple wgmma ops and iterating their descriptors. The prior code lacked clarity, but this PR addresses that issue effectively.

PR introduces a helper class `WarpgroupGemm`. This class encapsulates the necessary functionality, making the code cleaner and more understandable. Each function within the helper class is thoroughly documented to provide clear insights into its purpose and functionality.
@grypp force-pushed the nvgpu-mma-better-code branch from 1840749 to 2ca1556 on October 5, 2023 08:14
@grypp merged commit b74cfc1 into llvm:main on Oct 5, 2023
@grypp deleted the nvgpu-mma-better-code branch on October 5, 2023 08:17