diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index 45d608536a82..ca12088f0cfd 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1025,6 +1025,14 @@ void validateMma(Fusion* fusion) {
       case MmaOptions::MacroType::Volta_16_16_4:
         validateMinimumArch(7, 0);
         break;
+      case MmaOptions::MacroType::Turing_16_8_16:
+        validateMinimumArch(7, 5);
+
+        // Check that operands come from ldmatrix; this can be
+        // relaxed once swizzles can be labeled on iterdomains.
+        validateTuringMmaInput(mma->inA()->as<TensorView>());
+        validateTuringMmaInput(mma->inB()->as<TensorView>());
+        break;
       case MmaOptions::MacroType::Ampere_16_8_16:
         validateMinimumArch(8, 0);
diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp
index bf9347d69453..9ac067eb8b77 100644
--- a/torch/csrc/jit/codegen/cuda/mma_type.cpp
+++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp
@@ -18,6 +18,7 @@ MmaBuilder::MmaBuilder(
     case MmaOptions::MacroType::Volta_16_16_4:
       option_.accumulator_stride = outer_stride * 4;
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       option_.accumulator_stride = outer_stride * 2;
       break;
@@ -58,6 +59,7 @@ namespace {
 LoadStoreOpType getLdMatrixType(MmaOptions options) {
   bool transpose = false;
   switch (options.macro) {
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       // Turing mma assumes TN as default
       transpose = (options.operand == MmaOptions::Operand::A &&
@@ -84,7 +86,7 @@ bool isVolta(MmaOptions::MacroType macro) {
 }
 
 bool isTuring(MmaOptions::MacroType macro) {
-  return false;
+  return macro == MmaOptions::MacroType::Turing_16_8_16;
 }
 
 bool isAmpere(MmaOptions::MacroType macro) {
@@ -96,6 +98,7 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) {
     case MmaOptions::MacroType::Volta_16_16_4:
       return 8;
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       return 4;
       break;
@@ -111,6 +114,7 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
     case MmaOptions::MacroType::Volta_16_16_4:
       return 4;
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       return 8;
       break;
@@ -126,6 +130,7 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
     case MmaOptions::MacroType::Volta_16_16_4:
       return 4;
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       return 4;
     default:
@@ -176,6 +181,7 @@ std::string toString(MmaOptions::MacroType mt) {
     case MmaOptions::MacroType::Volta_16_16_4:
       ss << "M16N16K4";
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       ss << "M16N8K16";
       break;
diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h
index 30d7d2e34f23..6b94d74a4f5b 100644
--- a/torch/csrc/jit/codegen/cuda/mma_type.h
+++ b/torch/csrc/jit/codegen/cuda/mma_type.h
@@ -58,6 +58,7 @@ struct MmaOptions {
     NoMMA = 0,
     Volta_16_16_4,
     Ampere_16_8_16,
+    Turing_16_8_16,
     Ampere_16_8_8 // place holder for tf32
   };
 
@@ -73,7 +74,7 @@ struct MmaOptions {
   enum class MmaInputLayout { NT = 0, TT, TN };
 
   //! Utility to annotate which input of mma this option struct describes
-  enum class Operand { NotOperand = 0, A, B };
+  enum class Operand { Accumulator = 0, A, B };
 
   //! Utility to annotate which mma macro this config uses.
   MacroType macro = MacroType::NoMMA;
@@ -117,7 +118,7 @@ class TORCH_CUDA_CU_API MmaBuilder {
   //! Specifies which element in the mma op this builder is generating
   //!  parameters for, i.e. A or B. This is useful when generating
   //!  data swizzles for different elements of mma.
-  //! - Operand::NotOperand means the parameters describe accumulator in mma
+  //! - Operand::Accumulator means the parameters describe the accumulator in mma
   //!  op.
   //! - This option is ignored when configuring the mma operator itself.
   MmaBuilder& operand(MmaOptions::Operand a_or_b);
diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu
index 060f2920b0e3..a4745143a99b 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu
@@ -21,6 +21,40 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
   return smem_ptr_uint;
 }
 
+// LdMatrix has .x1, .x2 and .x4 options; currently we actively use .x2 and
+// .x4. With the .x2 option, the address registers of the upper half warp
+// (lanes 16-31) are unused, but on the Turing [sm75, sm80) architecture these
+// unused addresses still need to be valid, in the sense that:
+//   1. The data they point to has to be within the allocated shared mem
+//      buffer.
+//   2. The address needs to be aligned to 16 bytes.
+// See also:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
+// This function addresses 2. above by masking out the sub-16B component
+// of the address in the upper half warp; 1. is guaranteed by the ldmatrix
+// swizzle util.
+// This will **not** affect any functionality. It only modifies unused
+// pointers to satisfy the alignment requirement on Turing hardware.
+// The alignment requirement is lifted on sm80+,
+// so this function is a no-op on Ampere or above.
+DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) {
+#if (__CUDA_ARCH__ < 800)
+  const unsigned thread_id = threadIdx.x;
+  // With the .x2 option of ldmatrix, the upper half warp addresses are
+  // offset 8 bytes from 16B alignment. There is currently no support for
+  // .x1, so always adjust by half warp.
+  constexpr unsigned half_warp = 16;
+  // Need to adjust to 16 byte alignment; mask out the un-aligned component.
+  constexpr unsigned mask_out = 16 - 1;
+  // Adjust only in the upper half warp.
+  // Use bit math to reduce strength.
+  if (thread_id & half_warp) {
+    // Mask out the bits where mask_out has 1.
+    addr_in_byte &= (~mask_out);
+  }
+#endif //(__CUDA_ARCH__ < 800)
+}
+
 } // namespace util
 
 // Load Matrix (per warp instruction) is to take data from SMEM to Local Memory.
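// As an illustration of the masking above, a minimal standalone sketch
// (plain C++, not part of the runtime; the function and names here are
// illustrative only): the unused upper-half-warp addresses in the .x2
// option sit 8 bytes past a 16B boundary, and clearing the low four bits
// restores the required alignment.
//
//   #include <cassert>
//   unsigned maskToLdMatrixAlign(unsigned lane, unsigned addr) {
//     constexpr unsigned half_warp = 16;
//     constexpr unsigned mask_out = 16 - 1;
//     if (lane & half_warp) { // lanes 16-31 have bit 4 set
//       addr &= ~mask_out; // e.g. 0x128 -> 0x120
//     }
//     return addr;
//   }
//   int main() {
//     assert(maskToLdMatrixAlign(17, 0x128) == 0x120); // upper half warp
//     assert(maskToLdMatrixAlign(3, 0x128) == 0x128); // lower half untouched
//   }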
@@ -36,6 +70,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
 DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
   uint2& val = reinterpret_cast<uint2&>(out);
   unsigned addr = util::toSmem(ptr);
+  util::adjustPartialLdMatrixAddrInTuring(addr);
   asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
@@ -47,6 +82,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
 DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) {
   uint2& val = reinterpret_cast<uint2&>(out);
   unsigned addr = util::toSmem(ptr);
+  util::adjustPartialLdMatrixAddrInTuring(addr);
   asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
index 7d0c6be7c2be..c6976c197328 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
@@ -212,6 +212,74 @@ DEVICE_INLINE void initM16N16K4NT(Array<float, 8, 8>* accumulator) {
 
 } // namespace Volta
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
+
+namespace Turing {
+
+namespace util {
+// MMA instruction wrappers (sm_75+):
+DEVICE_INLINE void m16n8k16TN(
+    Array<float, 4, 4>* C,
+    Array<__half, 8, 8>* A,
+    Array<__half, 4, 4>* B) {
+  unsigned const* _A = reinterpret_cast<unsigned const*>(A);
+  unsigned const* _B = reinterpret_cast<unsigned const*>(B);
+  unsigned* _C = reinterpret_cast<unsigned*>(C);
+  const unsigned* _D = reinterpret_cast<const unsigned*>(C);
+
+  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
+      : "r"(_A[0]),
+        "r"(_A[1]),
+        "r"(_B[0]),
+        "r"(_D[0]),
+        "r"(_D[1]),
+        "r"(_D[2]),
+        "r"(_D[3]));
+  asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+      : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
+      : "r"(_A[2]),
+        "r"(_A[3]),
+        "r"(_B[1]),
+        "r"(_D[0]),
+        "r"(_D[1]),
+        "r"(_D[2]),
+        "r"(_D[3]));
+}
+
+} // namespace util
+
+template <int acc_stride>
+DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
+  float* _C = reinterpret_cast<float*>(accumulator);
+  _C[0] = 0;
+  _C[1] = 0;
+  _C[acc_stride] = 0;
+  _C[acc_stride + 1] = 0;
+}
+
+template <int acc_stride>
+DEVICE_INLINE void M16N8K16TN(
+    Array<float, 4, 4>* C,
+    Array<__half, 8, 8>* A,
+    Array<__half, 4, 4>* B) {
+  // TODO: in a follow up,
+  //    lift this fused swizzle onto iterdomain
+  float* _C = reinterpret_cast<float*>(C);
+  float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};
+
+  util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);
+
+  _C[0] = C_data[0];
+  _C[1] = C_data[1];
+  _C[acc_stride] = C_data[2];
+  _C[acc_stride + 1] = C_data[3];
+}
+
+} // namespace Turing
+
+#endif // Arch 75
+
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
 
 namespace Ampere {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
index 258688f399dc..6cb8be261011 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
@@ -219,6 +219,7 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
         setWarpMapped(tv, 5);
       }
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       scheduleTuringM16N8K16MmaWarpOutput(tv, options);
      if (tv->definition()->isA<MmaOp>()) {
@@ -240,6 +241,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
     case MmaOptions::MacroType::Volta_16_16_4:
       scheduleVoltaOperandRead(tv, options);
       break;
+    case MmaOptions::MacroType::Turing_16_8_16:
     case MmaOptions::MacroType::Ampere_16_8_16:
       scheduleTuringOperandRead(tv, options);
       break;
@@ -415,7 +417,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
                             : isOperandTransposed(options);
   // Check mma option is supported
   TORCH_CHECK(
-      options.macro == MmaOptions::MacroType::Ampere_16_8_16,
+      options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
+          options.macro == MmaOptions::MacroType::Turing_16_8_16,
       "scheduleLdMatrix: unknown macro for ldmatrix");
 
   if (options.operand == MmaOptions::Operand::A) {
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index 2bcb905173d7..d4d13a6a1fd7 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -1055,7 +1055,7 @@ bool TensorView::isEmptyTensor() const {
 
 void TensorView::applyMmaSwizzle(MmaOptions options) {
   switch (options.operand) {
-    case MmaOptions::Operand::NotOperand:
+    case MmaOptions::Operand::Accumulator:
       mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options);
       break;
     case MmaOptions::Operand::A:
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
index 44d5d781dadf..4bb2542f10f6 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -189,9 +189,9 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) {
   // Schedule the output instruction tile.
   // Assumes last 3 dims are mnk
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   // Set memory type.
   tv0cw->setMemoryType(MemoryType::Shared);
@@ -255,9 +255,9 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) {
   tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
   tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0cw->setMemoryType(MemoryType::Shared);
   tv1cw->setMemoryType(MemoryType::Shared);
@@ -323,9 +323,9 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) {
   tv2c->reorder({{0, 2}, {1, 0}, {2, 1}});
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0cw->setMemoryType(MemoryType::Shared);
   tv1cw->setMemoryType(MemoryType::Shared);
@@ -550,12 +550,12 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) {
   // Use WarpMmaSwizzler for the innermost instruction tile (Mi,Ni, Ki) on
   // output
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   //            -6   -5  -4  -3  -2  -1
   // [Mwo  Nwo  Mw  Nw  Mi  Ni]
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   // Inline broadcast with smem write.
   tv0b->computeAt(tv0cw, -2);
@@ -705,9 +705,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0b->computeAt(tv0cw, -2);
   tv1b->computeAt(tv1cw, -2);
@@ -858,9 +858,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0b->computeAt(tv0cw, -2);
   tv1b->computeAt(tv1cw, -2);
@@ -944,9 +944,9 @@ TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) {
   tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
   tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0cw->setMemoryType(MemoryType::Shared);
   tv1cw->setMemoryType(MemoryType::Shared);
@@ -1018,9 +1018,9 @@ TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) {
   // [M,K,N] -> [M,N,K]
   tv2c->reorder({{-2, -1}, {-1, -2}});
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0cw->setMemoryType(MemoryType::Shared);
   tv1cw->setMemoryType(MemoryType::Shared);
@@ -1097,9 +1097,9 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) {
   // [K,M,N] -> [M,N,K]
   tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}});
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   tv0cw->setMemoryType(MemoryType::Shared);
   tv1cw->setMemoryType(MemoryType::Shared);
@@ -1235,9 +1235,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   // Parallelize
   //  0   1  2  3   4   5  6  7  8  9  10
@@ -1387,9 +1387,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   // Parallelize
   //  0   1  2  3   4   5  6  7  8  9  10
@@ -1540,9 +1540,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv2c->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
   tv2->applyMmaSwizzle(
-      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
 
   // Parallelize
   //  0   1  2  3   4   5  6  7  8  9  10
@@ -1732,9 +1732,9 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv4c->applyMmaSwizzle(
-      mma_builder2.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
   tv4->applyMmaSwizzle(
-      mma_builder2.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
 
   // Schedule gemm 1:
   // ------------------------------------------------------------------
@@ -1802,13 +1802,13 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv3c->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
   tv3cw->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
   tv3h->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
   tv3->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
 
   tv3cw->setMemoryType(MemoryType::Shared);
 
   // Parallelize
@@ -2034,9 +2034,9 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
   // Schedule mma output
   // ---------------------------------------------------------------------------
   tv4c->applyMmaSwizzle(
-      mma_builder2.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
   tv4->applyMmaSwizzle(
-      mma_builder2.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
 
   // Schedule gemm 1:
   // ------------------------------------------------------------------
@@ -2112,9 +2112,9 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
   //
   //
   // ---------------------------------------------------------------------------
   tv3c->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
   tv3->applyMmaSwizzle(
-      mma_builder1.operand(MmaOptions::Operand::NotOperand).build());
+      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
 
   // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw,
   // mma_builder1.build());
@@ -2232,6 +2232,671 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
   TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001));
 }
 
+// MMA unit test on Turing
+TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [M,K]
+  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
+  // [N,K]
+  auto tv1 = makeConcreteTensor({8, 16}, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,N,K]
+  auto tv0b = broadcast(tv0, {false, true, false});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs outside the mma op for now,
+  // since they need to be swizzled.
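+  // Broadcasting to [M,N,K] and reducing axis 2 (K) computes
+  // C[m,n] = sum_k A[m,k] * B[n,k], i.e. a gemm with both operands
+  // K-innermost; this matches the TN layout the builder is configured
+  // with below.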
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(16, 8, 16);
+  gemm_tile.warp_tile = GemmTile(16, 8, 16);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TN);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0cw = tv0b->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1b->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+
+  auto tv2c = tv2->cacheBefore();
+
+  // [M,N,K] -> [N,M,K]
+  tv0cr->reorder({{-2, -3}, {-3, -2}});
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  tv0cw->setMemoryType(MemoryType::Shared);
+  tv1cw->setMemoryType(MemoryType::Shared);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 16}, options);
+  auto t1 = at::randn({8, 16}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
+
+  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
+}
+
+// MMA unit test on Turing
+TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [M,K]
+  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
+  // [K,N]
+  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,K,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(16, 8, 16);
+  gemm_tile.warp_tile = GemmTile(16, 8, 16);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TT);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0cw = tv0b->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1b->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+
+  auto tv2c = tv2->cacheBefore();
+
+  // [M,K,N] -> [N,M,K]
+  tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}});
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  // [M,K,N] -> [M,N,K]
+  tv1cr->reorder({{-2, -1}, {-1, -2}});
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // [M,K,N] -> [M,N,K]
+  tv2c->reorder({{-2, -1}, {-1, -2}});
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  tv0cw->setMemoryType(MemoryType::Shared);
+  tv1cw->setMemoryType(MemoryType::Shared);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 16}, options);
+  auto t1 = at::randn({16, 8}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
+}
+
+// MMA unit test on Turing
+TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // [K,M]
+  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
+  // [K,N]
+  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [K,M,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {false, true, false});
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(16, 8, 16);
+  gemm_tile.warp_tile = GemmTile(16, 8, 16);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::NT);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0cw = tv0b->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1b->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+
+  auto tv2c = tv2->cacheBefore();
+
+  // [K,M,N] -> [N,M,K]
+  tv0cr->reorder({{-3, -1}, {-1, -3}});
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  // [K,M,N] -> [M,N,K]
+  tv1cr->reorder({
+      {-3, -1},
+      {-2, -3},
+      {-1, -2},
+  });
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // [K,M,N] -> [M,N,K]
+  tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}});
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  tv0cw->setMemoryType(MemoryType::Shared);
+  tv1cw->setMemoryType(MemoryType::Shared);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 16}, options);
+  auto t1 = at::randn({16, 8}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
+}
+
+// Matmul test on Turing
+TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 511, N = 257, K = 88;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [N,K]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,N,K]
+  auto tv0b = broadcast(tv0, {false, true, false});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs outside the mma op for now,
+  // since they need to be swizzled.
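+  // Same TN definition as the unit test above, but at full problem size;
+  // M = 511, N = 257, K = 88 are presumably chosen to not divide evenly
+  // into the 128x128x32 CTA tile, so the tile-boundary predication of the
+  // schedule is exercised as well.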
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TN);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0r = tv0->cacheAfter();
+  auto tv1r = tv1->cacheAfter();
+  auto tv0cw = tv0r->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1r->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+  auto tv2c = tv2->cacheBefore();
+
+  // Make a CTA tile
+  // ------------------------------------------------------------------
+  // [M,N]
+  tv2->split(-2, gemm_tile.cta_tile.m);
+  tv2->split(-1, gemm_tile.cta_tile.n);
+
+  //  0   1    2   3
+  // [Mo,M128, No, N128]
+  tv2->reorder({{1, 2}, {2, 1}});
+
+  //  0   1   2    3
+  // [Mo,No, M128, N128]
+  tv0->computeAt(tv2, 2);
+  tv1->computeAt(tv2, 2);
+
+  // Order K
+  //  0   1   2    3     4   5
+  // [Mo,No, M128, N128, Ko, K32]
+  tv2c->split(-1, gemm_tile.cta_tile.k);
+  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});
+
+  //  0   1  2   3     4    5
+  // [Mo,No, Ko M128, N128, K32]
+  tv0r->computeAt(tv2c, 3);
+  tv1r->computeAt(tv2c, 3);
+
+  // Make warp tile:
+  // -------------------------------------------------------------------------
+  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
+  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
+      tv2, gemm_tile);
+  //           -8  -7  -6 -5 -4 -3 -2 -1
+  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+  tv0cr->computeAt(tv2c, -4);
+  tv1cr->computeAt(tv2c, -4);
+
+  // Schedule gmem read and smem write:
+  // ---------------------------------------------------------------------------
+  // [Mo,Ko,M,K]
+  tv0cw->merge(-2);
+  tv0r->merge(-2);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0r, gemm_tile, 8);
+  tv0cw->setMemoryType(MemoryType::Shared);
+  // [Mo,Ko,i,wy,wx,v]
+
+  // [No,Ko,N,K]
+  tv1cw->merge(-2);
+  tv1r->merge(-2);
+  // [No,Ko,i,wy,wx,v]
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1r, gemm_tile, 8);
+  tv1cw->setMemoryType(MemoryType::Shared);
+  // Schedule mma input
+  // ---------------------------------------------------------------------------
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
+  tv0b->reorder({{-2, -3}, {-3, -2}});
+  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // Schedule mma output
+  // ---------------------------------------------------------------------------
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  // Parallelize
+  //  0  1  2   3   4   5  6  7   8  9  10
+  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
+  tv2c->axis(3)->parallelize(ParallelType::TIDz);
+  tv2c->axis(4)->parallelize(ParallelType::TIDy);
+
+  // Parallelize
+  //  0  1   2   3   4  5   6  7
+  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::BIDy);
+  tv2->axis(2)->parallelize(ParallelType::TIDz);
+  tv2->axis(3)->parallelize(ParallelType::TIDy);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({N, K}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Turing
+TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [M,K]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [M,K,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {true, false, false});
+
+  // Leaving both sets of mma inputs outside the mma op for now,
+  // since they need to be swizzled.
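+  // With the [M,K,N] broadcast order here, axis 1 is K; reducing it gives
+  // C[m,n] = sum_k A[m,k] * B[k,n], i.e. the TT layout.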
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::TT);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0r = tv0->cacheAfter();
+  auto tv1r = tv1->cacheAfter();
+  auto tv0cw = tv0r->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1r->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+  auto tv2c = tv2->cacheBefore();
+
+  // Make a CTA tile
+  // ------------------------------------------------------------------
+  // [M,N]
+  tv2->split(-2, gemm_tile.cta_tile.m);
+  tv2->split(-1, gemm_tile.cta_tile.n);
+
+  //  0   1    2   3
+  // [Mo,M128, No, N128]
+  tv2->reorder({{1, 2}, {2, 1}});
+
+  //  0   1   2    3
+  // [Mo,No, M128, N128]
+  tv0->computeAt(tv2, 2);
+  tv1->computeAt(tv2, 2);
+
+  // Order K
+  //  0   1   2    3     4   5
+  // [Mo,No, M128, N128, Ko, K32]
+  tv2c->split(-1, gemm_tile.cta_tile.k);
+  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});
+
+  //  0   1  2   3     4    5
+  // [Mo,No, Ko M128, N128, K32]
+  tv0r->computeAt(tv2c, 3);
+  tv1r->computeAt(tv2c, 3);
+
+  // Make warp tile:
+  // -------------------------------------------------------------------------
+  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
+  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
+      tv2, gemm_tile);
+  //           -8  -7  -6 -5 -4 -3 -2 -1
+  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+  tv0cr->computeAt(tv2c, -4);
+  tv1cr->computeAt(tv2c, -4);
+
+  // Schedule gmem read and smem write:
+  // ---------------------------------------------------------------------------
+  // [Mo,Ko,M,K]
+  tv0cw->merge(-2);
+  tv0r->merge(-2);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0r, gemm_tile, 8);
+  tv0cw->setMemoryType(MemoryType::Shared);
+  // [Mo,Ko,i,wy,wx,v]
+
+  // [No,Ko,N,K] -> [No,Ko,K,N]
+  tv1cw->reorder({{-2, -1}, {-1, -2}});
+  tv1r->reorder({{-2, -1}, {-1, -2}});
+  tv1cw->merge(-2);
+  tv1r->merge(-2);
+  // [No,Ko,i,wy,wx,v]
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1r, gemm_tile, 8);
+  tv1cw->setMemoryType(MemoryType::Shared);
+  // Schedule mma input
+  // ---------------------------------------------------------------------------
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
+  tv0b->reorder({{-2, -3}, {-3, -2}});
+  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // Schedule mma output
+  // ---------------------------------------------------------------------------
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  // Parallelize
+  //  0  1  2   3   4   5  6  7   8  9  10
+  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
+  tv2c->axis(3)->parallelize(ParallelType::TIDz);
+  tv2c->axis(4)->parallelize(ParallelType::TIDy);
+
+  // Parallelize
+  //  0  1   2   3   4  5   6  7
+  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::BIDy);
+  tv2->axis(2)->parallelize(ParallelType::TIDz);
+  tv2->axis(3)->parallelize(ParallelType::TIDy);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({M, K}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
+// Matmul test on Turing
+TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  int M = 512, N = 256, K = 128;
+
+  // [K,M]
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  // [K,N]
+  auto tv1 = makeContigTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  // [K,M,N]
+  auto tv0b = broadcast(tv0, {false, false, true});
+  auto tv1b = broadcast(tv1, {false, true, false});
+
+  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
+
+  fusion.addOutput(tv2);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  auto mma_builder =
+      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
+          .layout(MmaOptions::MmaInputLayout::NT);
+
+  mma_builder.configureMma(tv2);
+
+  auto tv0r = tv0->cacheAfter();
+  auto tv1r = tv1->cacheAfter();
+  auto tv0cw = tv0r->cacheAfter();
+  auto tv0cr =
+      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
+  auto tv1cw = tv1r->cacheAfter();
+  auto tv1cr =
+      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
+  auto tv2c = tv2->cacheBefore();
+
+  // Make a CTA tile
+  // ------------------------------------------------------------------
+  // [M,N]
+  tv2->split(-2, gemm_tile.cta_tile.m);
+  tv2->split(-1, gemm_tile.cta_tile.n);
+
+  //  0   1    2   3
+  // [Mo,M128, No, N128]
+  tv2->reorder({{1, 2}, {2, 1}});
+
+  //  0   1   2    3
+  // [Mo,No, M128, N128]
+  tv0->computeAt(tv2, 2);
+  tv1->computeAt(tv2, 2);
+
+  // Order K
+  //  0   1   2    3     4   5
+  // [Mo,No, M128, N128, Ko, K32]
+  tv2c->split(-1, gemm_tile.cta_tile.k);
+  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});
+
+  //  0   1  2   3     4    5
+  // [Mo,No, Ko M128, N128, K32]
+  tv0r->computeAt(tv2c, 3);
+  tv1r->computeAt(tv2c, 3);
+
+  // Make warp tile:
+  // -------------------------------------------------------------------------
+  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
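+  // tv2c carries the K reduction, so it gets the reduction variant of the
+  // warp tiling; the epilogue tensor tv2 has no K dimension and gets the
+  // matching no-reduction variant below.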
+  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
+      tv2, gemm_tile);
+  //           -8  -7  -6 -5 -4 -3 -2 -1
+  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+  tv0cr->computeAt(tv2c, -4);
+  tv1cr->computeAt(tv2c, -4);
+
+  // Schedule gmem read and smem write:
+  // ---------------------------------------------------------------------------
+  // [Mo,Ko,M,K] -> [..., K,M]
+  tv0cw->reorder({{-2, -1}, {-1, -2}});
+  tv0r->reorder({{-2, -1}, {-1, -2}});
+  tv0cw->merge(-2);
+  tv0r->merge(-2);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv0r, gemm_tile, 8);
+  tv0cw->setMemoryType(MemoryType::Shared);
+  // [Mo,Ko,i,wy,wx,v]
+
+  // [No,Ko,N,K] -> [No,Ko,K,N]
+  tv1cw->reorder({{-2, -1}, {-1, -2}});
+  tv1r->reorder({{-2, -1}, {-1, -2}});
+  tv1cw->merge(-2);
+  tv1r->merge(-2);
+  // [No,Ko,i,wy,wx,v]
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1cw, gemm_tile, 8);
+  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
+      tv1r, gemm_tile, 8);
+  tv1cw->setMemoryType(MemoryType::Shared);
+  // Schedule mma input
+  // ---------------------------------------------------------------------------
+  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
+  tv0b->reorder({{-2, -3}, {-3, -2}});
+  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
+
+  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
+
+  // Schedule mma output
+  // ---------------------------------------------------------------------------
+  tv2c->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+  tv2->applyMmaSwizzle(
+      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
+
+  // Parallelize
+  //  0  1  2   3   4   5  6  7   8  9  10
+  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
+  tv2c->axis(3)->parallelize(ParallelType::TIDz);
+  tv2c->axis(4)->parallelize(ParallelType::TIDy);
+
+  // Parallelize
+  //  0  1   2   3   4  5   6  7
+  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::BIDy);
+  tv2->axis(2)->parallelize(ParallelType::TIDz);
+  tv2->axis(3)->parallelize(ParallelType::TIDy);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({K, M}, options);
+  auto t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+      7, 5, fe.compileFusion(&fusion, {t0, t1}));
+
+  auto cg_outputs = fe.runFusion({t0, t1});
+
+  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
+
+  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
+}
+
 #undef NVFUSER_TEST_CUDA_ARCH_GUARD
 
 } // namespace jit