diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 2d43230938526..91b6a25c6dfc0 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1576,27 +1576,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); - LLVM::LLVMStructType structType = + LLVM::LLVMStructType packStructType = getTypeConverter() ->convertType(op.getMatrixC().getType()) .cast(); - Type elemType = structType.getBody() + Type elemType = packStructType.getBody() .front() .cast() .getBody() .front(); Value zero = b.create(elemType, b.getZeroAttr(elemType)); - Value structValue = b.create(structType); - for (auto [idx, s] : llvm::enumerate(structType.getBody())) { - auto innerStructType = s.cast(); - int ii = idx; - Value innerStructValue = b.create(structValue, ii); - for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) { - innerStructValue = b.create( - innerStructType, innerStructValue, zero, ArrayRef({i})); + Value packStruct = b.create(packStructType); + SmallVector innerStructs; + // Unpack the structs and set all values to zero + for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { + auto structType = s.cast(); + Value structValue = b.create(packStruct, idx); + for (unsigned i = 0; i < structType.getBody().size(); ++i) { + structValue = b.create( + structType, structValue, zero, ArrayRef({i})); } + innerStructs.push_back(structValue); } - rewriter.replaceOp(op, structValue); + // Pack the inner structs into a single struct + for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { + packStruct = b.create(packStruct.getType(), + packStruct, matrix, idx); + } + rewriter.replaceOp(op, packStruct); return success(); } };