From a449f49d4e7a1c8bd402e369aa135e90544371dc Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 21 Feb 2022 15:32:38 -0800 Subject: [PATCH 01/57] initial volta support --- caffe2/CMakeLists.txt | 1 + test/cpp/jit/CMakeLists.txt | 1 + test/cpp/jit/test_gpu_tensorcore.cpp | 881 ++++++++++++++++++ tools/build_variables.bzl | 3 + torch/csrc/jit/codegen/cuda/arith.cpp | 128 +++ torch/csrc/jit/codegen/cuda/arith.h | 22 + torch/csrc/jit/codegen/cuda/codegen.cpp | 90 +- torch/csrc/jit/codegen/cuda/dispatch.cpp | 15 + torch/csrc/jit/codegen/cuda/dispatch.h | 4 + .../csrc/jit/codegen/cuda/executor_utils.cpp | 2 + torch/csrc/jit/codegen/cuda/index_compute.cpp | 7 +- torch/csrc/jit/codegen/cuda/ir_builder.cpp | 1 + torch/csrc/jit/codegen/cuda/ir_cloner.cpp | 4 + torch/csrc/jit/codegen/cuda/ir_cloner.h | 1 + .../jit/codegen/cuda/ir_interface_nodes.h | 23 + .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 111 +++ torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 5 + torch/csrc/jit/codegen/cuda/ir_iostream.h | 1 + torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 70 +- torch/csrc/jit/codegen/cuda/ir_utils.cpp | 23 + torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 2 +- .../jit/codegen/cuda/lower_allocation.cpp | 2 + torch/csrc/jit/codegen/cuda/lower_index.cpp | 9 + torch/csrc/jit/codegen/cuda/lower_index.h | 2 + .../csrc/jit/codegen/cuda/lower_predicate.cpp | 21 + torch/csrc/jit/codegen/cuda/lower_utils.cpp | 41 + torch/csrc/jit/codegen/cuda/lower_utils.h | 2 + .../jit/codegen/cuda/lower_validation.cpp | 43 +- torch/csrc/jit/codegen/cuda/mma_type.cpp | 137 +++ torch/csrc/jit/codegen/cuda/mma_type.h | 132 +++ torch/csrc/jit/codegen/cuda/mutator.cpp | 18 + .../csrc/jit/codegen/cuda/root_domain_map.cpp | 6 + torch/csrc/jit/codegen/cuda/root_domain_map.h | 5 + .../jit/codegen/cuda/runtime/tensorcore.cu | 215 +++++ .../jit/codegen/cuda/scheduler/mma_utils.cpp | 420 +++++++++ .../jit/codegen/cuda/scheduler/mma_utils.h | 144 +++ .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 92 ++ torch/csrc/jit/codegen/cuda/scheduler/utils.h | 17 + torch/csrc/jit/codegen/cuda/tensor_view.cpp | 16 + torch/csrc/jit/codegen/cuda/type.cpp | 2 + torch/csrc/jit/codegen/cuda/type.h | 1 + 41 files changed, 2710 insertions(+), 10 deletions(-) create mode 100644 test/cpp/jit/test_gpu_tensorcore.cpp create mode 100644 torch/csrc/jit/codegen/cuda/mma_type.cpp create mode 100644 torch/csrc/jit/codegen/cuda/mma_type.h create mode 100644 torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu create mode 100644 torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp create mode 100644 torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 54e1af80c758..d55a0c128d56 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -974,6 +974,7 @@ if(USE_CUDA OR USE_ROCM) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/warp.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensorcore.cu ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh ) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index cfdbb28a6765..df579591dfd8 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -95,6 +95,7 @@ set(JIT_TEST_SRCS if(USE_CUDA) list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu.cpp) list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp) + list(APPEND 
JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_tensorcore.cpp) endif() add_executable(test_jit diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp new file mode 100644 index 000000000000..7f0e0ac40dd9 --- /dev/null +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -0,0 +1,881 @@ +#if defined(USE_CUDA) +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// fuser and IR parser +#include "test_gpu_validator.h" + +#include +#include +#include + +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; +using namespace at::indexing; + +namespace { + +// Make a tensor that is known to be fully contiguous of dimensionality=ndims, +// but unknown sizes +TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { + return TensorViewBuilder() + .ndims(ndims) + .dtype(dtype) + .contiguity(std::vector(ndims, true)) + .build(); +} + +// Make a tensor that is known to be non-contiguous of dimensionality=ndims, +// but unknown sizes +TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { + return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); +} + +// Make a non-contiguous tensor of compile-time known sizes +TensorView* makeConcreteTensor( + std::vector shape, + DataType dtype = DataType::Float) { + return TensorViewBuilder().shape(shape).dtype(dtype).build(); +} + +void checkIntValue( + ExpressionEvaluator& evaluator, + Val* val, + Int::ScalarType expected_value) { + TORCH_CHECK(val->isAnInt()); + const auto actual_value = evaluator.evaluate(val); + TORCH_CHECK(actual_value.has_value()); + TORCH_CHECK(actual_value.value() == expected_value); +} + +void checkIntValue( + kir::ExpressionEvaluator& evaluator, + const Val* val, + Int::ScalarType expected_value) { + const auto actual_value = evaluator.evaluate(val); + TORCH_CHECK(actual_value.has_value()); + TORCH_CHECK(actual_value.value() == expected_value); +} + +bool cudaArchGuardShouldSkip(int required_major, int required_minor) { + int capability_major = at::cuda::getCurrentDeviceProperties()->major; + int capability_minor = at::cuda::getCurrentDeviceProperties()->minor; + + if (capability_major < required_major || + (capability_major == required_major && + capability_minor < required_minor)) { + return true; + } + return false; +} + +#define NVFUSER_TEST_CUDA_ARCH_GUARD(REQUIRED_MAJOR, REQUIRED_MINOR) \ + if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) { \ + GTEST_SKIP() << "Requires GPU capability above " << REQUIRED_MAJOR << "." \ + << REQUIRED_MINOR << " to run.\n"; \ + } + +} // namespace + +// MMA unit test for a single instruction tile. VoltaTT +TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 4}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({4, 16}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
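+  // fusedMultiplySum(a, b, {1}) instantiates the fused pattern
+  //   sum(mul(tv0b, tv1b), {1})
+  // as a single MmaOp (see arith.h below); axis 1 of [M,K,N] is the K
+  // (reduction) axis, so tv2 is the [M,N] matmul result.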
+ auto tv2 = fusedMultiplySum(tv0b, tv1b,{1}); + + fusion.addOutput(tv2); + + // TODO: should be able to completely remove it + // in a follow up. + GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 16, 4); + gemm_tile.warp_tile = GemmTile(16, 16, 4); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + tv2->configureMma(mma_builder.build()); + + // Write A to smem + auto tv0cw = tv0b->cache_after(); + // Read A from smem + auto tv0cr = tv0cw->cache_after(); + + // Write B to smem + auto tv1cw = tv1b->cache_after(); + + // Read B from smem + auto tv1cr = tv1cw->cache_after(); + + // Register accumulator + auto tv2c = tv2->cache_before(); + + // [M,K,N]->[M,N,K] + tv0cr->reorder({{-2, -1}, {-1, -2}}); + + // Schedule the instruction tile loops, which is the only + // part we have in this unit test. + // Assumes last 3 dims are mnk + // The innermost loops are dictated by the type of mma used, + // the scheduler needs to use mma_util::WarpMmaSwizzler to + // get the right thread swizzle. Currently this is the only + // method allowed to schedule the 3/2 inner most loops of + // mma input/output. + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [M,K,N]->[M,N,K] + tv1cr->reorder({{-2, -1}, {-1, -2}}); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [M,K,N]->[M,N,K] + tv2c->reorder({{-2, -1}, {-1, -2}}); + + // Schedule the output instruction tile. + // Assumes last 3 dims are mnk + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Set memory type. + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 4}, options); + auto t1 = at::randn({4, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// MMA unit test for a single instruction tile. VoltaTN +TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 4}, DataType::Half); + // [N,K] + auto tv1 = makeConcreteTensor({16, 4}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b,{2}); + + fusion.addOutput(tv2); + + // TODO: should be able to completely remove it + // in a follow up. 
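+  // GemmTile arguments are (m, n, k). For this single-instruction unit test
+  // the CTA, warp and instruction tiles all coincide with the
+  // Volta_16_16_4 macro tile.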
+ GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 16, 4); + gemm_tile.warp_tile = GemmTile(16, 16, 4); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after(); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after(); + auto tv2c = tv2->cache_before(); + + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 4}, options); + auto t1 = at::randn({16, 4}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// MMA unit test for a single instruction tile. VoltaNT +TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + Fusion fusion; + FusionGuard fg(&fusion); + + // [K,M] + auto tv0 = makeConcreteTensor({4, 16}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({4, 16}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
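+  // In the NT layout both operands have K as their outer dimension, so the
+  // broadcasted shape is [K,M,N] and the reduction axis passed below is 0.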
+ auto tv2 = fusedMultiplySum(tv0b, tv1b,{0}); + + fusion.addOutput(tv2); + + GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 16, 4); + gemm_tile.warp_tile = GemmTile(16, 16, 4); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after(); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after(); + auto tv2c = tv2->cache_before(); + + // To MNK + tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // To MNK + tv1cr->reorder({{0, 2}, {1, 0}, {2, 1}}); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + tv2c->reorder({{0, 2}, {1, 0}, {2, 1}}); + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({4, 16}, options); + auto t1 = at::randn({4, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// Gemm test for Volta MMA: TT +// This is the only example that is fully manual, +// the rest of them are facilitated by gemm utils. +TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + // Keep multiples of 8 to keep vectorizable. 
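+  // (8 contiguous half elements form one 16-byte, 128-bit access, matching
+  //  the vector_word = 8 used for the gmem/smem loads below.)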
+ int M = 264, N = 120, K = 248; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + auto tv2 = fusedMultiplySum(tv0b, tv1b,{1}); + + fusion.addOutput(tv2); + + GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after(); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after(); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + + // -3 -2 -1 + //[... M, N, K] + // Distribute warp tile: accumulator reg + tv2c->split(-3, gemm_tile.warp_tile.m); + tv2c->split(-2, gemm_tile.warp_tile.n); + + // -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K] + tv2c->split(-4, gemm_tile.instruction_tile.m); + tv2c->split(-2, gemm_tile.instruction_tile.n); + tv2c->split(-1, gemm_tile.instruction_tile.k); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Ko Ki] + tv2c->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + + // Distribute warp tile: output tensor + tv2->split(-2, gemm_tile.warp_tile.m); + tv2->split(-1, gemm_tile.warp_tile.n); + + // -4 -3 -2 -1 + // [Mwo Mw Nwo Nw ] + tv2->split(-3, gemm_tile.instruction_tile.m); + tv2->split(-1, gemm_tile.instruction_tile.n); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni] + tv2->reorder({{-5, -4}, {-4, -2}, {-3, -5}, {-2, -3}}); + // -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Mw Nw Mi Ni] + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // [Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile; + int num_of_thread = warp_dims.m * warp_dims.n * warp_dims.k * 32; + int vector_word = 8; + + // Smem write + tv0cw->split(-1, num_of_thread * vector_word); + tv0cw->split(-1, 8); + // [..., thread, vec] + // distribute to warp: + tv0cw->split(-2, 32); + tv0cw->split(-3, warp_dims.n * warp_dims.k); + + tv0cw->axis(-1)->parallelize(ParallelType::Vectorize); + tv0cw->axis(-2)->parallelize(ParallelType::TIDx); + 
tv0cw->axis(-3)->parallelize(ParallelType::TIDy); + tv0cw->axis(-4)->parallelize(ParallelType::TIDz); + + // Gmem read (reg staging) + tv0r->split(-1, num_of_thread * vector_word); + tv0r->split(-1, 8); + // [..., thread, vec] + // distribute to warp: + tv0r->split(-2, 32); + tv0r->split(-3, warp_dims.n * warp_dims.k); + + tv0r->axis(-1)->parallelize(ParallelType::Vectorize); + tv0r->axis(-2)->parallelize(ParallelType::TIDx); + tv0r->axis(-3)->parallelize(ParallelType::TIDy); + tv0r->axis(-4)->parallelize(ParallelType::TIDz); + + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1r->reorder({ + {-1, -2}, + {-2, -1}, + }); + tv1cw->reorder({ + {-1, -2}, + {-2, -1}, + }); + // [Mo,No,Ko,M,K,N] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + tv1r->split(-1, num_of_thread * vector_word); + tv1r->split(-1, 8); + // [..., thread, vec] + // distribute to warp: + tv1r->split(-2, 32); + tv1r->split(-3, warp_dims.n * warp_dims.k); + + tv1r->axis(-1)->parallelize(ParallelType::Vectorize); + tv1r->axis(-2)->parallelize(ParallelType::TIDx); + tv1r->axis(-3)->parallelize(ParallelType::TIDy); + tv1r->axis(-4)->parallelize(ParallelType::TIDz); + + tv1cw->split(-1, num_of_thread * vector_word); + tv1cw->split(-1, 8); + // [..., thread, vec] + // distribute to warp: + tv1cw->split(-2, 32); + tv1cw->split(-3, warp_dims.n * warp_dims.k); + + tv1cw->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cw->axis(-2)->parallelize(ParallelType::TIDx); + tv1cw->axis(-3)->parallelize(ParallelType::TIDy); + tv1cw->axis(-4)->parallelize(ParallelType::TIDz); + + tv1cw->setMemoryType(MemoryType::Shared); + + // Schedule mma input + // --------------------------------------------------------------------------- + + // Use WarpMmaSwizzler for the innermost instruction tile.(Mi, Ni, Ki) + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + // Use WarpMmaSwizzler for the innermost instruction tile (Mi,Ni, Ki) on + // output + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Mw Nw Mi Ni] + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Inline broadcast with smem write. 
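+  // (Position -2 computes the broadcasts inside every cache-write loop
+  //  except the innermost, vectorized one.)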
+ tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + // Vectorize smem read + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Gemm test for Volta MMA: TN +TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 120, N = 264, K = 56; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after(); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after(); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // 
[Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Gemm test for Volta MMA: NT +TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 240, N = 320, K = 136; + + // [K,M] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
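+  // As in the NT unit test above, the broadcasted shape is [K,M,N];
+  // reducing axis {0} (K) yields the [M,N] output.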
+ auto tv2 = fusedMultiplySum(tv0b, tv1b,{0}); + + fusion.addOutput(tv2); + + GemmTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after(); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after(); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); + // [Mo,No,Ko,N,K,M] + tv0cw->merge(-2); + + // [Mo,No,M,K] + tv0r->reorder({{-2, -1}, {-1, -2}}); + // [Mo,No,K,M] + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->reorder({{-2, -1}, {-1, -2}}); + tv1r->reorder({{-2, -1}, {-1, -2}}); + // [Mo,No,Ko,M,K,N] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + // [...M,N,K] + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + 
tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({K, M}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).t().matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +#undef NVFUSER_TEST_CUDA_ARCH_GUARD + +} // namespace jit +} // namespace torch + +#endif \ No newline at end of file diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 4d674ae828da..0a8ecf9f6534 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -45,6 +45,7 @@ libtorch_nvfuser_runtime_sources = [ "torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu", "torch/csrc/jit/codegen/cuda/runtime/helpers.cu", "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu", + "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu", "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu", "torch/csrc/jit/codegen/cuda/runtime/tensor.cu", "torch/csrc/jit/codegen/cuda/runtime/welford.cu", @@ -695,6 +696,8 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/transform_view.cpp", "torch/csrc/jit/codegen/cuda/type.cpp", "torch/csrc/jit/codegen/cuda/utils.cpp", + "torch/csrc/jit/codegen/cuda/mma_type.cpp", + "torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp", "torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp", "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", "torch/csrc/jit/runtime/register_cuda_ops.cpp", diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 89b2e4f0409f..86fe7720e4b1 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -1493,6 +1493,134 @@ TensorView* gather( return out_tv; } +namespace { + +//! Create new output for mma +static TensorView* newForMma( + TensorView* tv_a, + TensorView* tv_b, + const std::vector& axes, + DataType data_type = DataType::Float) { + auto orig_domain_a = + TensorDomain::noReductions(tv_a->getMaybeRFactorDomain()); + auto orig_domain_b = + TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); + + TORCH_INTERNAL_ASSERT( + orig_domain_a.size() == orig_domain_b.size(), + "MMA op: need matching dim input"); + + std::set axes_set(axes.begin(), axes.end()); + std::vector new_domain; + + TORCH_INTERNAL_ASSERT( + !axes_set.empty(), + "Asked for ouput of reduction, but no reduction axis provided."); + + TORCH_INTERNAL_ASSERT( + (*(axes_set.rbegin())) < orig_domain_a.size(), + "Error setting up reduction, reduction axis (", + *(axes_set.rbegin()), + ") is outside nDims (", + orig_domain_a.size(), + "). Keep in mind reductions are relative to root domains, not modified views."); + + auto axis_iter = axes_set.begin(); + for (const auto dim : c10::irange(orig_domain_a.size())) { + bool isReduction = false; + if (axis_iter != axes_set.end() && *axis_iter == dim) { + isReduction = true; + axis_iter++; + } + + const IterDomain* id = orig_domain_a[dim]->isBroadcast() + ? 
orig_domain_b[dim] + : orig_domain_a[dim]; + + TORCH_CHECK( + !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()), + "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ", + id, + " of tensor ", + tv_a, + "and", + tv_b); + + new_domain.push_back(IrBuilder::create( + id->start(), + id->extent(), + id->stopOffset(), + ParallelType::Serial, + isReduction ? IterType::Reduction : id->getIterType())); + } + + TensorDomain* td = IrBuilder::create( + new_domain, std::vector(new_domain.size(), true)); + + return IrBuilder::create(td, data_type); +} + +} // namespace + +TensorView* fusedMultiplySum( + TensorView* tv_a, + TensorView* tv_b, + const std::vector& axes, + Val* init) { + if (init == nullptr) { + init = IrBuilder::create(0); + } + + // TODO: + // We will want to support initialize and rfactor with + // mma as well, for maybe fusing bias in prolog. + // TODO: check init type if given a tv, + // not supported currently though. + TORCH_CHECK( + init->isConstScalar(), + "Cannot create a reduction operation where the initial value is not a const scalar."); + + // TODO: + // Validate axis relationships between a and b + TORCH_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor"); + + // TODO: + // Add tf32 + TORCH_CHECK(tv_a->getDataType().value() == DataType::Half); + TORCH_CHECK(tv_b->getDataType().value() == DataType::Half); + + TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); + + // TODO: + // will lift this in a follow up when we have a + // more generic axes matching. + TORCH_CHECK( + axes.size() == 1, "Single axis reduction only for mma op instantiation.") + + std::vector uint_axes; + const int ndims = tv_a->domain()->noReductions().size(); + for (int axis : axes) { + if (axis < 0) { + axis += ndims; + } + + TORCH_CHECK( + axis >= 0 && axis < ndims, + "Reduction on invalid axis, recieved: ", + axis, + " however tensor view only has ", + ndims, + " non-reduction dims."); + + uint_axes.push_back((unsigned int)axis); + } + + TensorView* out = newForMma(tv_a, tv_b, uint_axes); + IrBuilder::create(out, tv_a, tv_b, init); + + return out; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 1f18f65666ad..3451bd0c12b5 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -561,6 +561,28 @@ TORCH_CUDA_CU_API TensorView* gather( const std::vector& strides = {}, bool trim_out_of_bounds = false); +//! A fused pointwise multiply and sum +//! operator that instantiates the following +//! fused pattern: +//! c = mul(tv_a, tv_b); +//! return sum(c, axes) +//! +//! \param tv_a first multiply operand +//! \param tv_b first multiply operand +//! \param axes axes to sum over +//! \param init sum initial value +//! +//! Note & TODO: +//! currently only support lowering to a mma op +//! through this interface. +//! will support converting back to multiply and reduce in +//! a follow up. 
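+//!
+//! Usage sketch (mirroring the unit tests added in this patch), for an
+//! [M,K] x [K,N] product:
+//!   auto tv_a_b = broadcast(tv_a, {false, false, true});  // [M,K,N]
+//!   auto tv_b_b = broadcast(tv_b, {true, false, false});  // [M,K,N]
+//!   auto tv_c   = fusedMultiplySum(tv_a_b, tv_b_b, {1});  // sum over K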
+TORCH_CUDA_CU_API TensorView* fusedMultiplySum( + TensorView* tv_a, + TensorView* tv_b, + const std::vector& axes, + Val* init = nullptr); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 16cc459d7a0d..a71a0591a475 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -359,6 +360,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { bool is_vector_op = false; size_t vector_word_size = 1; + if (uop->out()->isA()) { + if (auto mma = dynamic_cast( + uop->out()->as()->view()->definition())) { + genMmaInitialization(mma, uop); + return; + } + } + if (vectorize_scope_ && uop->out()->isA()) { auto ti = uop->out()->as(); @@ -721,6 +730,74 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + std::string genArchString(MmaOptions options) { + std::stringstream ss; + if (isVolta(options.macro)) { + ss << "Volta"; + } else if (isTuring(options.macro)) { + ss << "Turing"; + } else if (isAmpere(options.macro)) { + ss << "Ampere"; + } else { + TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch"); + } + return ss.str(); + } + + std::string genMmaOp(const MmaOp* mma, bool init = false) { + std::stringstream ss; + auto options = mma->options(); + ss << genArchString(options) << "::"; + if (init) { + ss << "init"; + } + ss << toString(options.macro) << toString(options.operand_layout); + // TODO: additional parameter could be removed by swizzling iterdomain + auto acc_stride = mma->accStride(); + TORCH_INTERNAL_ASSERT(acc_stride > 0); + ss << "<" << acc_stride << ">"; + return ss.str(); + } + + void genMmaOperands(const MmaOp* mma) { + std::stringstream ss; + auto options = mma->options(); + auto in_a = mma->inA()->as()->view(); + auto dtype = in_a->getDataType().value(); + indent() << kTab << "reinterpret_cast*>(&" + << gen(mma->inA()) << "),\n"; + indent() << kTab << "reinterpret_cast*>(&" + << gen(mma->inB()) << ")"; + } + + void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) { + auto options = mma->options(); + + indent() << genMmaOp(mma, true) << "(reinterpret_castout()->getDataType().value() << "," + << getOutputRegisterSize(mma->options().macro) << "," + << getOutputRegisterSize(mma->options().macro) << ">*>" + << "(&" << gen(uop->out()) << "));\n"; + } + + void handle(const MmaOp* mma) final { + auto options = mma->options(); + auto in_a = mma->inA()->as(); + auto out = mma->out()->as(); + indent() << genMmaOp(mma) << "(\n"; + indent() << kTab << "reinterpret_castview()->getDataType().value() << "," + << getOutputRegisterSize(options.macro) << "," + << getOutputRegisterSize(options.macro) << ">*>(&" + << gen(mma->out()) << "),\n"; + genMmaOperands(mma); + code_ << ");\n"; + } + std::string genReductionOp(BinaryOpType op_type, Val* out) { std::stringstream lambda; DataType data_type = out->dtype(); @@ -1291,8 +1368,17 @@ class CudaKernelGenerator : private OptOutConstDispatch { case MemoryType::Shared: if (kir::ExpressionEvaluator::isConst(size)) { // Static shared memory - indent() << "__shared__ " << buffer_dtype << " " << varName(tv) - << "[" << genInline(size) << "];\n"; + indent(); + + // align to 16B if any access is vectorized + // TODO: + // This is a WAR to support vectorized access, + // eventually want to always use dynamic smem alloc. 
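+          // (16 bytes covers the widest vectorized access generated here:
+          //  8 half elements per load/store.)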
+ if (hasVectorizedAccess(alloc)) { + code_ << "__align__(16) "; + } + code_ << "__shared__ " << buffer_dtype << " " << varName(tv) << "[" + << genInline(size) << "];\n"; } else { // Align Offset Position indent() << "offset = alignBufferSize(offset," diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 8f39da0bc818..5f91206e5b86 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -104,6 +104,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::MmaOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; @@ -233,6 +236,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::MmaOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; @@ -373,6 +379,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::WelfordOp: ptr(mutator)->mutate(expr->as()); return; + case ExprType::MmaOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::BroadcastOp: ptr(mutator)->mutate(expr->as()); return; @@ -578,6 +587,9 @@ void OptOutConstDispatch::handle(const ReductionOp* stmt) { void OptOutConstDispatch::handle(const WelfordOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const MmaOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const BroadcastOp* stmt) { unhandled(stmt); } @@ -680,6 +692,9 @@ void OptOutDispatch::handle(ReductionOp* stmt) { void OptOutDispatch::handle(WelfordOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(MmaOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(BroadcastOp* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index a35efec48a06..b1d0a10b5f79 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -73,6 +73,7 @@ class BinaryOp; class TernaryOp; class ReductionOp; class WelfordOp; +class MmaOp; class BroadcastOp; class TransposeOp; class ShiftOp; @@ -133,6 +134,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const TernaryOp* stmt); virtual void handle(const ReductionOp* stmt); virtual void handle(const WelfordOp* stmt); + virtual void handle(const MmaOp* stmt); virtual void handle(const BroadcastOp* stmt); virtual void handle(const Split* stmt); @@ -182,6 +184,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(TernaryOp* stmt); virtual void handle(ReductionOp* stmt); virtual void handle(WelfordOp* stmt); + virtual void handle(MmaOp* stmt); virtual void handle(BroadcastOp* stmt); virtual void handle(Split* stmt); @@ -272,6 +275,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(TernaryOp*); virtual void mutate(ReductionOp*); virtual void mutate(WelfordOp*); + virtual void mutate(MmaOp*); virtual void mutate(BroadcastOp*); virtual void mutate(Split*); diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 4c2d3c729bf0..d38a4356e390 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include 
@@ -90,6 +91,7 @@ std::string kernelPreamble() { ss << nvfuser_resources::broadcast_cu; ss << nvfuser_resources::welford_cu; ss << nvfuser_resources::warp_cu; + ss << nvfuser_resources::tensorcore_cu; // Random utilities ss << nvfuser_resources::PhiloxCudaStateRaw_cu; diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 3b7c16677a8c..5e1fc232d986 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1102,7 +1102,8 @@ indexMapFromTV( // Similarly for local memory tensors, zero replacement can be // only done when there's a matching domain with the same // parallel type - (loop->iter_domain()->isThread() && is_local && same_parallel_type)) { + (loop->iter_domain()->isThread() && is_local && + (same_parallel_type || loop->iter_domain()->isInstruction()))) { idx = GpuLower::current()->kernel()->zeroVal(); zero_loops.insert(loop); } else { @@ -2075,8 +2076,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops); std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true); + std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( + consumer_tv, loops, alloc_info.init_for_loop, true, nullptr); ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops); diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index 695d6377cb73..fcf65247263b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -62,6 +62,7 @@ IR_BUILDER_INSTANTIATE(BinaryOp) IR_BUILDER_INSTANTIATE(TernaryOp) IR_BUILDER_INSTANTIATE(ReductionOp) IR_BUILDER_INSTANTIATE(WelfordOp) +IR_BUILDER_INSTANTIATE(MmaOp) IR_BUILDER_INSTANTIATE(BroadcastOp) Val* IrBuilder::newResult(DataType dtype) { diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 6ed4c27b7c68..8adc3885b928 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -112,6 +112,10 @@ void IrCloner::handle(const WelfordOp* op) { clone_ = IrBuilder::clone(op, this); } +void IrCloner::handle(const MmaOp* op) { + clone_ = IrBuilder::clone(op, this); +} + void IrCloner::handle(const TransposeOp* op) { clone_ = IrBuilder::clone(op, this); } diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index d62fa6769f86..123b69724a87 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -74,6 +74,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const BroadcastOp*) override; void handle(const ReductionOp*) override; void handle(const WelfordOp*) override; + void handle(const MmaOp*) override; void handle(const TransposeOp*) override; void handle(const ShiftOp*) override; void handle(const GatherOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 0974c25efde8..e796512263a8 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -433,6 +433,29 @@ class TORCH_CUDA_CU_API TensorView : public Val { return is_double_buffered_; } + //! Fill in mma options in scheduling time. + //! 
Each mma op in Fusion IR must be configured once before lowering. + //! Mma options are configuration parameters used in lowering to mma + //! instrinsics, mainly the type of mma macro to use and input data layout + //! etc. + //! + //! TODO: This step will very likely be removed in a follow up PR. All of + //! the options configured here could actually be inferred from fusion IR + //! once we are feature complete. + void configureMma(MmaOptions options) { + TORCH_CHECK(definition(), "configureMma: invalid for input tensor ", this); + auto mma = dynamic_cast(definition()); + TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", this); + mma->configureOptions(options); + } + + //! Transforms the innermost iterdomains according to the given mma swizzle, + //! this should be used on the tvs that are either inputs/outputs of an + //! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to + //! have a matching thread swizzle with the mma operand/result. + //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . + void applyMmaSwizzle(MmaOptions options); + friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index bb494148be21..90ca0a8a6e5f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -5,6 +5,7 @@ #include #include #include +#include #include //! Nodes in here should generally not be used by users. They should be behind @@ -24,6 +25,12 @@ class ViewTransform; class Scope; class IrCloner; +namespace mma_util { + +class WarpMmaSwizzler; + +} // namespace mma_util + //! Returns true if both v1 and v2 are scalars, are the same type of scalars, //! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both //! vals are `Int` will dispatch to v1->as()->sameAs(v2.as()) @@ -262,6 +269,53 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { Val* const in_N_; }; +//! Fused Matmul operation +class TORCH_CUDA_CU_API MmaOp : public Expr { + public: + MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); + + MmaOp(const MmaOp* src, IrCloner* ir_cloner); + + Val* out() const { + return out_; + } + + Val* inA() const { + return in_a_; + } + + Val* inB() const { + return in_b_; + } + + Val* init() const { + return init_; + } + + const auto& options() const { + TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); + return options_.value(); + } + + bool sameAs(const Statement* const other) const override; + + auto accStride() const { + TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); + return options_->accumulator_stride; + } + + void configureOptions(MmaOptions options) { + options_ = options; + } + + private: + Val* const out_; + Val* const in_a_; + Val* const in_b_; + Val* const init_; + c10::optional options_; +}; + class TORCH_CUDA_CU_API TransposeOp : public Expr { public: TransposeOp( @@ -662,10 +716,59 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return definition() == nullptr; } + bool isInstruction() const { + return is_instruction_; + } + + bool isMmaSwizzled() const { + return is_mma_swizzled_; + } + + //! Used by WarpMmaSwizzler, marks that this id represents a + //! instruction loop, mma use only. + //! + //! An instruction loop can be considered a generalization of + //! vectorization. 
It also represents a loop that's implemented + //! by an instruction and should not be realized by codegen and + //! cannot be inlined with. + //! As an example, if a mma macro, call it mma_eg implements: + //! for m in M + //! for n in N + //! for k in K + //! C[m,n] += A[m,k]*B[k,n], + //! But the generated code should simply be: + //! mma_eg(C,A,B) + //! without the 3 level loopnest, i.e. they're instruction loops. + //! + //! In the actual mma macros, the loopnests it implements is a + //! transformed version of above to match the mma swizzle. + //! So it's different implicit loopnest for different macros. + //! WarpMmaSwizzler will label the instruction loops case-by-case. + void toInstruction() { + is_instruction_ = true; + } + + //! Used by WarpMmaSwizzler, this is an utility for WarpMmaSwizzler + //! to lock the thread swizzled iterdomains. + //! Only true for the iterdomains produced by WarpMmaSwizzler. + //! Mma ops require specific swizzle patterns + //! and this label utility is to prevent any further transform on the + //! iterdomains involved in the swizzle so that the pattern remain correct in + //! generated code. + //! + //! Note: + //! Used only through WarpMmaSwizzler only and mma validation relies on + //! this + //! flag being set on the correct iterdomains. + void toMmaSwizzled() { + is_mma_swizzled_ = true; + } + protected: friend TensorDomain; friend ReplayTransformations; friend IndexReferenceReplay; + friend mma_util::WarpMmaSwizzler; private: //! Valid range is defined as [start:-stop_offset] @@ -682,6 +785,14 @@ class TORCH_CUDA_CU_API IterDomain : public Val { // TODO: Remove only used in kernel IR because IterDomains don't maintain // definitions of split/merge. bool is_simple_ = true; + + // Tracks if this id is implicit in an instruction, i.e mma + bool is_instruction_ = false; + + //! Tracks if this id represents a thread swizzled loop or + //! models an implicit loop within instructions. Should not make + //! any changes once an id is warp mapped. + bool is_mma_swizzled_ = false; }; //! TensorDomain holds a vector of IterDomains. 
It holds an IterDomain for every diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 0c4fa3a0a242..7b1d00003e2e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -454,6 +454,11 @@ void IrPrinter::handle(const ShiftOp* sop) { << "}, {" << sop->padWidth() << "} )\n"; } +void IrPrinter::handle(const MmaOp* mma) { + indent() << mma->out() << " = mma(" << mma->inA() << "," << mma->inB(); + os_ << ")\n"; +} + void IrPrinter::handle(const GatherOp* op) { indent() << op->out() << " = gather( " << op->in() << ", {"; bool no_comma = true; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 29900dd76528..02d8f2d3f2df 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -87,6 +87,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const TernaryOp*) final; void handle(const ReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; void handle(const TransposeOp*) final; void handle(const ShiftOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 975050a986ac..66cf2397c405 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -525,6 +525,57 @@ bool WelfordOp::sameAs(const Statement* other) const { return false; } +MmaOp::MmaOp( + IrBuilderPasskey passkey, + Val* out, + Val* in_a, + Val* in_b, + Val* init) + : Expr(passkey, ExprType::MmaOp), + out_(out), + in_a_(in_a), + in_b_(in_b), + init_(init) { + // Check output type + TORCH_INTERNAL_ASSERT( + out->getValType().value() == ValType::TensorView || + out->getValType().value() == ValType::TensorIndex); + + TORCH_INTERNAL_ASSERT( + in_a->getValType().value() == ValType::TensorView || + in_a->getValType().value() == ValType::TensorIndex, + in_a->getValType().value()); + + TORCH_INTERNAL_ASSERT( + in_b->getValType().value() == ValType::TensorView || + in_b->getValType().value() == ValType::TensorIndex, + in_b->getValType().value()); + + addOutput(out); + addInput(in_a); + addInput(in_b); +} + +MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + out_(ir_cloner->clone(src->out_)), + in_a_(ir_cloner->clone(src->in_a_)), + in_b_(ir_cloner->clone(src->in_b_)), + init_(ir_cloner->clone(src->init_)), + options_(src->options_) {} + +bool MmaOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (auto other_mma = dynamic_cast(other)) { + return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) && + in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) && + options_ == other_mma->options_; + } + return false; +} + ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), reduction_op_type_(src->reduction_op_type_), @@ -797,7 +848,9 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) iter_type_(src->iter_type_), is_rfactor_domain_(src->is_rfactor_domain_), is_padded_dimension_(src->is_padded_dimension_), - padded_to_size_(src->padded_to_size_) {} + padded_to_size_(src->padded_to_size_), + is_instruction_(src->is_instruction_), + is_mma_swizzled_(src->is_mma_swizzled_) {} bool IterDomain::sameAs(const Statement* other) const { if (other == this) { @@ -816,6 +869,7 @@ bool IterDomain::sameAs(const 
Statement* other) const { is_same = is_same && ScalarCheck::sameAs(start(), other_id->start()); is_same = is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset()); + is_same = is_same && (is_instruction_ == other_id->is_instruction_); return is_same; } @@ -1008,6 +1062,12 @@ void IterDomain::parallelize(ParallelType t) { extent(), " ."); } + + if (isMmaSwizzled()) { + TORCH_CHECK( + t == ParallelType::Vectorize, + "Parallel type other than vectorize not allowed for warp mapped ids"); + } } bool IterDomain::maybePartial() const { @@ -1344,6 +1404,10 @@ void TensorDomain::split( "Partial split is only allowed with root domains"); } + TORCH_INTERNAL_ASSERT( + !id->isMmaSwizzled(), + "Further transformation on warp mapped id's not allowed."); + auto split_ids = IterDomain::split(id, factor, inner_split, trim_out_of_bounds); domain_.erase(domain_.begin() + axis_); @@ -1379,6 +1443,10 @@ void TensorDomain::merge(int axis_o, int axis_i) { IterDomain* first = axis(axis_o); IterDomain* second = axis(axis_i); + TORCH_INTERNAL_ASSERT( + !first->isMmaSwizzled() && !second->isMmaSwizzled(), + "Further transformation on warp mapped id's not allowed."); + IterDomain* merged_id = IterDomain::merge(first, second); domain_.erase(domain_.begin() + axis_i); diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 004cfa23dff4..b43c59207531 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -312,6 +312,29 @@ struct SubstituteInExpr : public OptInDispatch { in_N); } + void handle(MmaOp* mma_expr) final { + TORCH_INTERNAL_ASSERT( + substitute_->isA(), + "All args to MmaOp must be TensorView, but received a non-TensorView for replacement: ", + substitute_); + auto in_a = reference_->sameAs(mma_expr->inA()) + ? substitute_->as() + : mma_expr->inA(); + auto in_b = reference_->sameAs(mma_expr->inB()) + ? substitute_->as() + : mma_expr->inB(); + auto out = reference_->sameAs(mma_expr->out()) + ? substitute_->as() + : mma_expr->out(); + auto init = reference_->sameAs(mma_expr->init()) + ? 
substitute_->as() + : mma_expr->init(); + auto options = mma_expr->options(); + expr_ = + IrBuilder::create(mma_expr->container(), out, in_a, in_b, init); + expr_->as()->configureOptions(options); + } + private: Val* reference_ = nullptr; Val* substitute_ = nullptr; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 48774e73618f..03b7087ab64d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -302,7 +302,7 @@ Val* ForLoop::step() const { bool ForLoop::isTrivial() const { // These loops are not materialized if (vectorize() || iter_domain()->isBroadcast() || - iter_domain()->isStride()) { + iter_domain()->isStride() || iter_domain()->isInstruction()) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index c03848ccff86..bb2f8b173fde 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -453,6 +453,8 @@ class AllocationInserter : public kir::ExprMutator { default_val == nullptr, "Reduction should not have a default initialization value for predicate elimination."); init = expr->as()->init(); + } else if (expr->isA()) { + init = expr->as()->init(); } else if (expr->isA()) { TORCH_INTERNAL_ASSERT( default_val == nullptr, diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index b0ef14079c43..717f1e9f0e25 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -371,6 +371,15 @@ void IndexLowering::handle(const WelfordOp* wop) { } } +void IndexLowering::handle(const MmaOp* mma) { + const auto a = lowerSrcIndex(mma->inA(), mma->out()); + const auto b = lowerSrcIndex(mma->inB(), mma->out()); + const auto out = lowerDstIndex(mma->out()); + auto mma_indexed = IrBuilder::create(out, a, b, mma->init()); + mma_indexed->configureOptions(mma->options()); + pushBack(mma_indexed); +} + void IndexLowering::handle(const BroadcastOp* bop) { TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(bop)); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 2f3af0061e18..cdc2be95fe2b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -35,6 +35,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const TernaryOp*) final; void handle(const ReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; void handle(const kir::ForLoop*) final; @@ -45,6 +46,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void generate(const std::vector& exprs); Val* lowerSrcIndex(Val* val, Val* dst) const; + Val* lowerDstIndex(Val* dst) const; private: diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index cd34c56b510e..631caae2b25c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -126,6 +126,12 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { } }; +void assertOnWarpOps(const Expr* expr) { + TORCH_INTERNAL_ASSERT( + !expr->isA(), + "Mma op: cannot eliminate predicate for mma op, tiling not valid"); +} + } // namespace std::vector generateConditionalFromPredicate( @@ -151,11 +157,17 @@ class PredicateAnalyzer : public 
OptOutDispatch { // of the parallelized axis is the actual size of the axis, not // the number of threads. Since the number of threads can be // larger than the axis size, it's not safe to skip predication + bool has_global_access = producer->getMemoryType() == MemoryType::Global || + consumer->getMemoryType() == MemoryType::Global; + + bool needs_sharedmem_addr_pred = false; + // Check that parallel dimension will not generate out of bound index if (!(producer->getMemoryType() == MemoryType::Local && consumer->getMemoryType() == MemoryType::Local)) { return true; } + bool needs_index_predicate = false; auto pairwise_map = PairwiseRootDomainMap(producer, consumer); auto c2p = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) @@ -355,6 +367,10 @@ void PredicateElimination::handle(Expr* expr) { } if (needsPredicate(expr)) { + // Warp primitives are currently limited to un-predicated usage, + // predicating these ops will require extra steps to ensure that + // the whole warp will get the same value. + assertOnWarpOps(expr); return; } @@ -392,6 +408,11 @@ void PredicateElimination::handle(Expr* expr) { continue; } + if (expr->isA()) { + setReductionInitValue(input, expr->as()->init()); + continue; + } + // If an input does not need a predicate either, then it should // have some value, so no need to set a default value if (non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 1d0096c18d62..9a13597f64d1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -93,6 +93,7 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::TernaryOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::WelfordOp || + expr->getExprType().value() == ExprType::MmaOp || expr->getExprType().value() == ExprType::BroadcastOp || expr->getExprType().value() == ExprType::TransposeOp || expr->getExprType().value() == ExprType::ShiftOp || @@ -455,6 +456,19 @@ class ReplaceExprInput : private kir::ExprMutator { } } + void handle(MmaOp* node) final { + auto replaced_inputs = getMaybeInputReplacementMap(node); + if (replaced_inputs.has_value()) { + auto replacement = IrBuilder::create( + node->out(), + replaced_inputs.value().at(node->inA()), + replaced_inputs.value().at(node->inB()), + node->init()); + replacement->configureOptions(node->as()->options()); + registerReplaceWithPredicate(node, replacement); + } + } + private: const std::unordered_map& replacement_map_; }; @@ -467,6 +481,33 @@ std::vector replaceInputsInExpr( return ReplaceExprInput::replace(exprs, replacement_map); } +bool hasVectorizedAccess(const kir::Allocate* alloc) { + auto buffer_tv = dynamic_cast(alloc->buffer()); + TORCH_INTERNAL_ASSERT( + buffer_tv != nullptr, "checking vectorization on non-tv allocation"); + + for (auto id : buffer_tv->domain()->domain()) { + if (id->getParallelType() == ParallelType::Vectorize) { + return true; + } + } + auto uses = buffer_tv->fusion()->unordered_uses(buffer_tv); + + for (auto use : uses) { + for (auto out : use->outputs()) { + if (auto out_tv = dynamic_cast(out)) { + for (auto id : out_tv->domain()->domain()) { + if (id->getParallelType() == ParallelType::Vectorize) { + return true; + } + } + } + } + } + + return false; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h 
b/torch/csrc/jit/codegen/cuda/lower_utils.h index 4ed6c25e731a..a7981e724539 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -137,6 +137,8 @@ std::vector replaceInputsInExpr( const std::vector& exprs, const std::unordered_map& replacement_map); +bool hasVectorizedAccess(const kir::Allocate* alloc); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 856c757efa0e..8f4007bba0bb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -10,6 +10,7 @@ #include #include +#include #include namespace torch { @@ -454,7 +455,9 @@ void validateVectorize(Fusion* fusion) { namespace { // Validate parallelization of a single tensor -void validateParallelizationOfTensor(TensorView* tv) { +void validateParallelizationOfTensor( + TensorView* tv, + bool is_mma_output = false) { // Each ParallelType can be used only once. ParallelTypeBitmap pt_map; for (size_t i = 0; i < tv->nDims(); ++i) { @@ -463,7 +466,32 @@ void validateParallelizationOfTensor(TensorView* tv) { if (!isParallelTypeThread(ptype)) { continue; } - + if (is_mma_output && ptype == ParallelType::TIDx) { + TORCH_INTERNAL_ASSERT( + axis->isMmaSwizzled(), + "TIDx for mma output is reserved for warp mapping", + axis, + tv); + // Check that TIDx is exact lane_id + const auto& paralel_dim_map = GpuLower::current()->parallelDimensionMap(); + TORCH_INTERNAL_ASSERT( + paralel_dim_map.isExact(ptype), + "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); + TORCH_INTERNAL_ASSERT( + paralel_dim_map.get(ptype)->getInt().has_value() && + paralel_dim_map.get(ptype)->getInt().value() == + at::cuda::warp_size(), + "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); + + auto maybe_dim = paralel_dim_map.get(ptype); + TORCH_INTERNAL_ASSERT(maybe_dim != nullptr); + ExpressionEvaluator const_eval(tv->fusion()); + auto maybe_dim_value = const_eval.evaluate(maybe_dim); + TORCH_INTERNAL_ASSERT( + maybe_dim_value.has_value() && + maybe_dim_value.value() == at::cuda::warp_size(), + "Mma: TIDx reserved for lane id"); + } // It doesn't matter if this axis is a non-concretized broadcast // TODO: merging broadcast and non-broadcast if (axis->isBroadcast() && @@ -518,9 +546,10 @@ void validateParallelize(Fusion* fusion) { if (!ir_utils::isTvOp(expr)) { continue; } + bool is_mma = expr->getExprType().value() == ExprType::MmaOp; // Validate parallelization of each consumer by itself for (auto consumer : ir_utils::filterByType(expr->outputs())) { - validateParallelizationOfTensor(consumer); + validateParallelizationOfTensor(consumer, is_mma); } // Validate parallelization between a producer and a consumer for (auto producer : ir_utils::filterByType(expr->inputs())) { @@ -544,6 +573,14 @@ void validateParallelize(Fusion* fusion) { if (!isParallelTypeThread(producer_ptype)) { continue; } + if (is_mma && producer_ptype == ParallelType::TIDx) { + TORCH_INTERNAL_ASSERT( + producer_axis->isMmaSwizzled(), + "mma input: use WarpMmaMapper to schedule warp input"); + // Warp mapped ids will not map across mma op since they have + // different swizzle formats. 
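+        // Skip the producer-consumer parallel-type check for TIDx here;
+        // each side's lane mapping is validated through its own swizzle.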
+ continue; + } // When the producer axis is a broadcast, it is not really // parallelized unless thread-predicated if (producer_axis->isBroadcast() && diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp new file mode 100644 index 000000000000..58280c47fc36 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -0,0 +1,137 @@ +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +MmaBuilder::MmaBuilder(MmaOptions::MacroType macro, GemmTileOptions gemm_tile) { + option_.macro = macro; + // Calculate accumulator stride, will be removed once transpose swizzle ready + int outer_stride = gemm_tile.warp_tile.n / gemm_tile.instruction_tile.n; + switch (macro) { + // Numbers depend on actual output layout of mma instruction + case MmaOptions::MacroType::Volta_16_16_4: + option_.accumulator_stride = outer_stride * 4; + break; + default: + TORCH_CHECK(false, "unsupported macro"); + break; + } +} + +MmaBuilder& MmaBuilder::layout(MmaOptions::MmaInputLayout layout) { + option_.operand_layout = layout; + return *this; +} + +MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) { + option_.operand = a_or_b; + return *this; +} + +// TODO: validate op config +MmaOptions MmaBuilder::build() const { + return option_; +} + +bool isVolta(MmaOptions::MacroType macro) { + return macro == MmaOptions::MacroType::Volta_16_16_4; +} + +bool isTuring(MmaOptions::MacroType macro) { + return macro == MmaOptions::MacroType::Turing_16_8_16; +} + +bool isAmpere(MmaOptions::MacroType macro) { + return false; +} + +int getOutputRegisterSize(MmaOptions::MacroType macro) { + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + return 8; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown macro"); + break; + } + return -1; +} + +int getInputARegisterSize(MmaOptions::MacroType macro) { + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + return 4; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown macro"); + break; + } + return -1; +} + +int getInputBRegisterSize(MmaOptions::MacroType macro) { + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + return 4; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown macro"); + break; + } + return -1; +} + +bool isOperandTransposed(MmaOptions options) { + switch (options.operand) { + case MmaOptions::Operand::A: + return options.operand_layout == MmaOptions::MmaInputLayout::TT || + options.operand_layout == MmaOptions::MmaInputLayout::TN; + case MmaOptions::Operand::B: + return options.operand_layout == MmaOptions::MmaInputLayout::TT || + options.operand_layout == MmaOptions::MmaInputLayout::NT; + default: + TORCH_CHECK(false, "isOperandTransposed: please specify operand"); + } + return false; +} + +std::string toString(MmaOptions::MmaInputLayout input_layout) { + std::stringstream ss; + switch (input_layout) { + case MmaOptions::MmaInputLayout::TT: + ss << "TT"; + break; + case MmaOptions::MmaInputLayout::TN: + ss << "TN"; + break; + case MmaOptions::MmaInputLayout::NT: + ss << "NT"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported operand layout"); + } + return ss.str(); +} + +std::string toString(MmaOptions::MacroType mt) { + std::stringstream ss; + switch (mt) { + case MmaOptions::MacroType::NoMMA: + ss << "NoOp"; + break; + case MmaOptions::MacroType::Volta_16_16_4: + ss << "M16N16K4"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "undefined mma type"); + break; + } + return ss.str(); +} + +} // namespace cuda 
+} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h new file mode 100644 index 000000000000..407ca3c97725 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -0,0 +1,132 @@ +#pragma once +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +//! Utility data structure for recording gemm tiles +struct GemmTile { + int m, n, k; + GemmTile(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {} + + bool operator==(const GemmTile& other) { + return m == other.m && n == other.n && k == other.k; + } + + GemmTile operator/(const GemmTile& other) { + return GemmTile(m / other.m, n / other.n, k / other.k); + } +}; + +//! Utility data structure for recording gemm tiles +struct TORCH_CUDA_CU_API GemmTileOptions { + GemmTile cta_tile = GemmTile(128, 128, 32); + GemmTile warp_tile = GemmTile(64, 64, 32); + GemmTile instruction_tile = GemmTile(16, 8, 16); + + GemmTileOptions() = default; + GemmTileOptions( + GemmTile cta_tile_, + GemmTile warp_tile_, + GemmTile instruction_tile_) + : cta_tile(cta_tile_), + warp_tile(warp_tile_), + instruction_tile(instruction_tile_) {} + + bool operator==(const GemmTileOptions& other) { + return cta_tile == other.cta_tile && warp_tile == other.warp_tile && + instruction_tile == other.instruction_tile; + } +}; + +//! Information for configuring and lowering mma ops +struct MmaOptions { + //! Type of mma instrinsic macro to use + //! This will translate to which mma intrinsic from runtime string + //! to be generated to implement the mma op. The current plan + //! is to have exactly one macro for each + //! (arch, datatype, operand layout) triple, though there + //! exists multiple possibilities for some cases, e.g. for Turing and fp16 + //! one can use 16_8_8 or 16_8_16. + //! Will consider adding more choices that the scheduler can pick from + //! when our perf target becomes more fine grained, which is more likely in + //! latency bound kernels. + enum class MacroType { + NoMMA = 0, + Volta_16_16_4, + Turing_16_8_16, // place holder for turing/ampere mma + Ampere_16_8_8 // place holder for tf32 + }; + + //! [Operand Layout Convention] + //! Operand layout, T=transposed/row_major, N=normal/col_major + //! We don't support calling NN mma directly since it implies + //! a fused transpose. User needs to swap the operands and use + //! TT mma to make the transpose explicit. + //! Ordered by position of K + //! NT : K,M x K,N -> K,M,N + //! TT : M,K X K,N -> M,K,N + //! TN : M,K X N,K -> M,N,K + enum class MmaInputLayout { NT = 0, TT, TN }; + + //! Utility to annotate which input of mma this option struct describes + enum class Operand { NotOperand = 0, A, B }; + + //! Utility to annotate which mma macro this config uses. + MacroType macro = MacroType::NoMMA; + + //! Utility to annotate transposition of operands + MmaInputLayout operand_layout = MmaInputLayout::TT; + + //! Utility to annotate which input of mma this option struct describes + Operand operand = Operand::A; + + //! Accumulator register stride, will be removed when the swizzle op + //! is introduced and the output can be labeled with a transpose swizzle. + int accumulator_stride = 0; + + bool operator==(const MmaOptions& other) const { + return macro == other.macro && operand_layout == other.operand_layout && + operand == other.operand && + accumulator_stride == other.accumulator_stride; + } +}; + +//! 
User interface generating mma options for mma op +class TORCH_CUDA_CU_API MmaBuilder { + public: + MmaBuilder(MmaOptions::MacroType macro, GemmTileOptions gemm_tile); + MmaBuilder& layout(MmaOptions::MmaInputLayout layout); + MmaBuilder& operand(MmaOptions::Operand a_or_b); + MmaOptions build() const; + + private: + MmaOptions option_; +}; + +//! GPU arch check for macro type +bool isVolta(MmaOptions::MacroType macro); +bool isTuring(MmaOptions::MacroType macro); +bool isAmpere(MmaOptions::MacroType macro); + +//! Returns true if the given option describes a transposed operand +bool isOperandTransposed(MmaOptions options); + +// Unpacked constants from macro type: +// exact numbers are defined by each individual instruction. +int getOutputRegisterSize(MmaOptions::MacroType macro); +int getInputARegisterSize(MmaOptions::MacroType macro); +int getInputBRegisterSize(MmaOptions::MacroType macro); + +// MMA stringify utils +std::string toString(MmaOptions::MacroType macro); +std::string toString(MmaOptions::MmaInputLayout input_layout); +std::string toString(MmaOptions::MacroType mt); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 894455da0ded..0f240c3ad6d4 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -235,6 +235,24 @@ void OptOutMutator::mutate(WelfordOp* wop) { in_N); } +void OptOutMutator::mutate(MmaOp* mma) { + Val* out = maybeMutated(mma->out()); + Val* in_a = maybeMutated(mma->inA()); + Val* in_b = maybeMutated(mma->inB()); + Val* init = mma->init(); + + if (out->sameAs(mma->out()) && in_a->sameAs(mma->inA()) && + in_b->sameAs(mma->inB())) { + return; + } + + auto container = mma->container(); + auto options = mma->options(); + container->removeExpr(mma); + auto new_mma = IrBuilder::create(container, out, in_a, in_b, init); + new_mma->configureOptions(options); +} + void OptOutMutator::mutate(BroadcastOp* bop) { Val* out = maybeMutated(bop->out()); Val* in = maybeMutated(bop->in()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index da521bd855f0..fc5f517b4d67 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -285,6 +285,12 @@ void UnmappableReductionDomains::handle(ReductionOp* op) { handleReductionOutput(out_tv); } +void UnmappableReductionDomains::handle(MmaOp* mma) { + // Builds a map from reduction domains to consumer domains. + TensorView* out_tv = mma->out()->as(); + handleReductionOutput(out_tv); +} + void UnmappableReductionDomains::handle(WelfordOp* op) { // Builds a map from reduction domains to consumer domains. 
handleReductionOutput(op->outAvg()->as()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index 366801f4ceea..4d7ccfb3dec3 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -187,6 +187,7 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { using IterVisitor::handle; void handle(ReductionOp* op) override; void handle(WelfordOp* op) override; + void handle(MmaOp* op) override; void handleReductionOutput(TensorView* out_tv); @@ -393,6 +394,10 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder mapPointwiseOrReductionOp(wop); } + void handle(MmaOp* wop) override { + mapPointwiseOrReductionOp(wop); + } + void handle(ShiftOp* op) override { mapPointwiseOrReductionOp(op); } diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu new file mode 100644 index 000000000000..f95978e84475 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -0,0 +1,215 @@ +// Utility macro for this file +#define DEVICE_INLINE __device__ inline + +// MMA instruction wrappers: +// The wrappers are subroutines that implement matrix of size +// A(M,K) X B(K,N) = C(M,N) +// The naming of the wrappers follow similar naming conventions +// as the mma instructions. +// All the mma macros follow the namespace and naming like +// Arch::M (M-dim) N (N-dim) K(K-dim) (Layout), eg. +// Volta::M16N16K4TT, +// with the dimensions describing the size of the sub-matrices being +// multiplied by this wrapper. +// see [Operand Layout Convention] in mma_type.h for details on the layout +// notation. +namespace Volta { + +namespace util { +// MMA instruction wrappers (sm_70+): +// The instruction wrappers below are quarter-warp macros, which currently +// nvfuser +// doesn't explicitly model. 
So they are currently only meant to be +// used as building blocks in warp level mma macros + +// 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate +// per thread register: +// A[4] x B[4] -> C[8] +DEVICE_INLINE void mmaM8n8k4tt( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + + asm("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=r"(_C[0]), + "=r"(_C[1]), + "=r"(_C[2]), + "=r"(_C[3]), + "=r"(_C[4]), + "=r"(_C[5]), + "=r"(_C[6]), + "=r"(_C[7]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_B[0]), + "r"(_B[1]), + "r"(_C[0]), + "r"(_C[1]), + "r"(_C[2]), + "r"(_C[3]), + "r"(_C[4]), + "r"(_C[5]), + "r"(_C[6]), + "r"(_C[7])); +} + +DEVICE_INLINE void mmaM8n8k4tn( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=r"(_C[0]), + "=r"(_C[1]), + "=r"(_C[2]), + "=r"(_C[3]), + "=r"(_C[4]), + "=r"(_C[5]), + "=r"(_C[6]), + "=r"(_C[7]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_B[0]), + "r"(_B[1]), + "r"(_C[0]), + "r"(_C[1]), + "r"(_C[2]), + "r"(_C[3]), + "r"(_C[4]), + "r"(_C[5]), + "r"(_C[6]), + "r"(_C[7])); +} + +DEVICE_INLINE void mmaM8n8k4nt( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + + asm("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=r"(_C[0]), + "=r"(_C[1]), + "=r"(_C[2]), + "=r"(_C[3]), + "=r"(_C[4]), + "=r"(_C[5]), + "=r"(_C[6]), + "=r"(_C[7]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_B[0]), + "r"(_B[1]), + "r"(_C[0]), + "r"(_C[1]), + "r"(_C[2]), + "r"(_C[3]), + "r"(_C[4]), + "r"(_C[5]), + "r"(_C[6]), + "r"(_C[7])); +} + +// TODO: in a follow up, +// lift this part onto iterdomain ops, once the +// swizzle ops are ready. +template +DEVICE_INLINE Array accToMma(float* _C) { + float C_data[8] = { + _C[0], + _C[1], + _C[acc_stride], + _C[acc_stride + 1], + _C[2], + _C[3], + _C[acc_stride + 2], + _C[acc_stride + 3], + }; + + return *reinterpret_cast*>(&C_data[0]); +} + +template +DEVICE_INLINE void mmaToAcc(float* _C, Array& C) { + float* C_data = reinterpret_cast(&C); + _C[0] = C_data[0]; + _C[1] = C_data[1]; + _C[acc_stride] = C_data[2]; + _C[acc_stride + 1] = C_data[3]; + _C[2] = C_data[4]; + _C[3] = C_data[5]; + _C[acc_stride + 2] = C_data[6]; + _C[acc_stride + 3] = C_data[7]; +} + +// Should be able to lift this with transpose op as well. 
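+// Zero-fills the eight accumulator values this thread owns for one 16x16
+// output tile, writing through mmaToAcc so the zeros land at the strided
+// register positions that the M16N16K4 macros above read and write.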
+template +DEVICE_INLINE void initM16N16K4(Array& accumulator) { + float* _C = reinterpret_cast(&accumulator); + float zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + mmaToAcc(_C, *reinterpret_cast*>(&zeros[0])); +} + +} // namespace util + +template +DEVICE_INLINE void M16N16K4TT( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + float* _C = reinterpret_cast(C); + Array C_data = util::accToMma(_C); + util::mmaM8n8k4tt(&C_data, A, B); + util::mmaToAcc(_C, C_data); +} + +template +DEVICE_INLINE void M16N16K4TN( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + float* _C = reinterpret_cast(C); + Array C_data = util::accToMma(_C); + util::mmaM8n8k4tn(&C_data, A, B); + util::mmaToAcc(_C, C_data); +} + +template +DEVICE_INLINE void M16N16K4NT( + Array* C, + Array<__half, 4, 4>* A, + Array<__half, 4, 4>* B) { + float* _C = reinterpret_cast(C); + Array C_data = util::accToMma(_C); + util::mmaM8n8k4nt(&C_data, A, B); + util::mmaToAcc(_C, C_data); +} + +// Same initialization for now, will be different in interleaved +// macros +template +DEVICE_INLINE void initM16N16K4TT(Array* accumulator) { + util::initM16N16K4(*accumulator); +} + +template +DEVICE_INLINE void initM16N16K4TN(Array* accumulator) { + util::initM16N16K4(*accumulator); +} + +template +DEVICE_INLINE void initM16N16K4NT(Array* accumulator) { + util::initM16N16K4(*accumulator); +} + +} // namespace Volta + +#undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp new file mode 100644 index 000000000000..8d3f74852edd --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -0,0 +1,420 @@ + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace mma_util { + +namespace { + +// Utility for mma dimension matching +enum class MmaDimension { M = 0, N, K }; + +// Utility for mma dimension matching, assumes the innermost +// 3 dimensions are the mma operand dimensions, i.e. mnk, but +// not necessarily in this order. +// For matmul use cases the root domains are always 3 dimensional, +// but this wouldn't be the case for other kernels such as batched gemm. +// This utility only applies to the case where the innermost 3 dims +// are the one that mma's are used. We probably don't want to use +// mma intrinsics if that's not the case. +IterDomain* getMmaOperandRootDimension3d( + TensorView* tv, + MmaOptions::MmaInputLayout layout, + MmaDimension mma_dimension) { + TORCH_INTERNAL_ASSERT(tv->getMaybeRFactorDomain().size() >= 3); + // NT : K,M x K,N -> K,M,N + // TT : M,K X K,N -> M,K,N + // TN : M,K X N,K -> M,N,K + int axis_id = -1; + switch (mma_dimension) { + case MmaDimension::K: + axis_id = (int)layout; + break; + case MmaDimension::M: + axis_id = layout == MmaOptions::MmaInputLayout::NT ? 1 : 0; + break; + case MmaDimension::N: + axis_id = layout == MmaOptions::MmaInputLayout::TN ? 1 : 2; + break; + default: + TORCH_INTERNAL_ASSERT(false, "Unreachable"); + break; + } + + int root_size = tv->getMaybeRFactorDomain().size(); + // Convert to index from right. + return tv->getMaybeRFactorDomain().at(root_size + axis_id - 3); +} + +// Locate the root id corresponding to the given mma dimension +// Assumes the mma dimension always the innermost 2 or 3, might +// need to extend for more complex fusions. 
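+// E.g. with a TT operand layout the innermost three root ids are ordered
+// [M, K, N] (see [Operand Layout Convention] in mma_type.h), so
+// MmaDimension::M resolves to the third-from-innermost root id and
+// MmaDimension::N to the innermost one.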
+IterDomain* getMmaOperandRootDimension( + TensorView* tv, + MmaOptions options, + MmaDimension mma_dimension) { + if (isVolta(options.macro)) { + return getMmaOperandRootDimension3d( + tv, options.operand_layout, mma_dimension); + } + TORCH_INTERNAL_ASSERT(false, "unreachable"); + return nullptr; +} + +// Preliminary checks to try to validate that leaf is +// a innermost dim of root of exactly the given size. +bool canValidateIsInnerDim( + IterDomain* root, + IterDomain* leaf, + int inner_dim_size) { + // Accept boundary case for Volta. + if (leaf == root && leaf->isBroadcast()) { + return true; + } + auto expr = leaf->definition(); + ExpressionEvaluator const_eval(leaf->fusion()); + auto maybe_leaf_size = const_eval.evaluate(leaf->extent()); + if (!maybe_leaf_size.has_value()) { + return false; + } + if (maybe_leaf_size.value() != inner_dim_size) { + return false; + } + + while (expr) { + if (auto split = dynamic_cast(expr)) { + // Inner split only + if (leaf != split->inner()) { + return false; + } + // Const split only + auto maybe_factor = const_eval.evaluate(split->factor()); + if (!maybe_factor.has_value()) { + return false; + } + int factor = maybe_factor.value(); + if (factor < inner_dim_size) { + // This might be too restrictive. Would need more + // bookkeeping to relax. + return false; + } + leaf = split->in(); + } else if (auto merge = dynamic_cast(expr)) { + // Might consider just rejecting merge. + auto outer = merge->outer(); + auto inner = merge->inner(); + if (outer->isBroadcast()) { + return false; + } + + // Only support merging with constant sized dims + maybe_leaf_size = const_eval.evaluate(leaf->extent()); + if (!maybe_leaf_size.has_value()) { + return false; + } + if (maybe_leaf_size.value() != inner_dim_size) { + return false; + } + leaf = merge->inner(); + } else { + // No support for swizzled inner dim for now. + // Might need to add transpose swizzle here. 
+ return false; + } + expr = leaf->definition(); + } + return leaf == root; +} + +} // namespace + +void checkDimSize( + TensorView* tv, + std::vector axis, + std::vector expect) { + ExpressionEvaluator const_eval(tv->fusion()); + for (auto axis_index : c10::irange(axis.size())) { + auto id = tv->axis(axis[axis_index]); + auto maybe_extent = const_eval.evaluate(id->extent()); + TORCH_CHECK( + maybe_extent.has_value(), + "Mma warp mapping: instruction tile has to be constant"); + TORCH_CHECK( + maybe_extent.value() == expect[axis_index], + "Mma warp mapping: unexpected tile size at", + axis_index, + ":", + maybe_extent.value(), + "vs", + expect[axis_index]); + } +} + +void WarpMmaSwizzler::scheduleMmaWarpOutput( + TensorView* tv, + MmaOptions options) { + auto macro = options.macro; + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + scheduleVoltaM16N16K4Fp32Output(tv, options); + if (tv->definition()->isA()) { + setWarpMapped(tv, 5); + } + break; + default: + TORCH_CHECK( + false, "scheduleMmaWarp: unsupported mma option ", toString(macro)); + break; + } +} + +void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) { + // Schedules operand for inner most 3 contiguous dimensions + // Assumes M, N, K + + switch (options.macro) { + case MmaOptions::MacroType::Volta_16_16_4: + scheduleVoltaOperandRead(tv, options); + break; + default: + TORCH_CHECK(false, "WarpMmaSwizzler: please specify macro"); + break; + } +} + +void WarpMmaSwizzler::setWarpMapped(TensorView* tv, int number_of_dims) { + for (int id : c10::irange(number_of_dims)) { + tv->axis(-id - 1)->toMmaSwizzled(); + } +} + +namespace { + +// Utility to check operand innermost scheduled dimensions +void validateInnerMNK(TensorView* tv, MmaOptions options, int m, int n, int k) { + TORCH_INTERNAL_ASSERT(tv->nDims() >= 3); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::M), + tv->axis(-3), + m)); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::N), + tv->axis(-2), + n)); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + k)); +} + +void validateResultInnerMN(TensorView* tv, int m, int n) { + TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); + int root_dim = tv->getMaybeRFactorDomain().size(); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + tv->getMaybeRFactorDomain()[root_dim - 2], tv->axis(-2), m)); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), m)); +} + +void scheduleVoltaA(TensorView* tv, MmaOptions options) { + // Assumed: + // [..., 16, 16 ,4] + // [..., M, BN, K] + // Some validation: + validateInnerMNK(tv, options, 16, 16, 4); + bool transposed = isOperandTransposed(options); + + tv->split(-3, 4); + + // Split out 16 from the bcast + tv->split(-2, 16); + tv->split(-2, 8); + + // -6 -5 -4 -3 -2 -1 + //[Mo4, Mi4, Noo, No2, Ni8, K] + + if (transposed) { + tv->reorder({{-5, -3}, {-3, -5}}); + // -6 -5 -4 -3 -2 -1 + //[Mo4, No2, Noo, Mi4, Ni8, K] + + } else { + tv->reorder({{-5, -1}, {-3, -5}, {-1, -3}}); + // -6 -5 -4 -3 -2 -1 + //[Mo4, No2, Noo, K, Ni8, Mi4] + } + + tv->merge(-6); + tv->merge(-5); + tv->merge(-4); + + //[Warp, Ni8, K/Mi4] + tv->axis(-3)->parallelize(ParallelType::TIDx); +} + +void scheduleVoltaB(TensorView* tv, MmaOptions options) { + // Assumed: + // [..., 16,16,4] + // [..., BM, N, K] + // Some validation: + validateInnerMNK(tv, options, 16, 16, 4); + 
+ bool transposed = isOperandTransposed(options); + tv->split(-3, 16); + tv->split(-3, 8); + + tv->split(-2, 8); + tv->split(-2, 4); + + // -7 -6 -5 -4 -3 -2 -1 + //[Moo, Mo2, Mi8, No2, Nio2, Nii4, K] + tv->reorder({{-6, -4}, {-5, -6}, {-4, -3}, {-3, -5}}); + + // -7 -6 -5 -4 -3 -2 -1 + //[Moo, Mi8, Nio2, Mo2, No2, Nii4, K ] + if (transposed) { + tv->reorder({{-2, -1}, {-1, -2}}); + // -7 -6 -5 -4 -3 -2 -1 + //[Moo, Mi8, Nio2, Mo2, No2, K, Nii4] + } + + tv->merge(-5); + tv->merge(-4); + tv->merge(-3); + + //[Moo, Mi8, Warp, K/Nii4] + tv->axis(-2)->parallelize(ParallelType::TIDx); +} + +} // namespace + +void WarpMmaSwizzler::scheduleVoltaOperandRead( + TensorView* tv, + MmaOptions options) { + switch (options.operand) { + case MmaOptions::Operand::A: + scheduleVoltaA(tv, options); + setWarpMapped(tv, 3); + break; + case MmaOptions::Operand::B: + scheduleVoltaB(tv, options); + setWarpMapped(tv, 4); + break; + default: + TORCH_CHECK(false, "WarpMmaSwizzler: please specify operand"); + } +} + +// Fp32 and Fp16 outputs have different layouts on volta, +// but we only support fp32 accumulate at this stage. +void WarpMmaSwizzler::scheduleVoltaM16N16K4Fp32Output( + TensorView* tv, + const MmaOptions& options) { + // Assume last 2 dims [M16, N16] or [M16, N16, R] + bool is_reduction = tv->axis(-1)->isReduction(); + + // Make sure instruction tile size is correct. + if (is_reduction) { + validateInnerMNK(tv, options, 16, 16, 4); + } else { + validateResultInnerMN(tv, 16, 16); + } + + int m_pos = is_reduction ? -3 : -2; + + // Assumed: + // m + // [..., 16,16, (4)] + // [..., M, N, (R)] + tv->split(m_pos, 4); + tv->split(m_pos, 2); + tv->split(m_pos + 1, 8); + tv->split(m_pos + 1, 4); + tv->split(m_pos + 1, 2); + + // m-5 m-4 m-3 m-2 m-1 m m+1 m+2 + // [..., Mo4, Mio2, Mii2, No2, Nio2, Niio2, Niii2, (R)] + tv->reorder( + {{m_pos - 4, m_pos - 1}, + {m_pos - 3, m_pos - 2}, + {m_pos - 2, m_pos - 4}, + {m_pos - 1, m_pos}, + {m_pos, m_pos - 3}}); + + // m-5 m-4 m-3 m-2 m-1 m m+1 m+2 + // [..., Mo4, No2, Niio2, Mii2, Mio2, Nio2, Niii2, (R)] + + tv->merge(m_pos - 5); + tv->merge(m_pos - 4); + tv->merge(m_pos - 3); + + // m-2 m-1 m m+1 m+2 + //[Warp, Mio2, Nio2, Niii2, (R)] + tv->axis(m_pos - 2)->parallelize(ParallelType::TIDx); + + if (is_reduction && tv->definition()->isA()) { + // Set instruction loops for mma reduce output + for (int pos : c10::irange(5)) { + tv->axis(-pos - 1)->toInstruction(); + tv->axis(-pos - 1)->toMmaSwizzled(); + } + } +} + +namespace { + +bool isMmaInitLoop(const kir::Scope& loop_body) { + for (auto expr : loop_body.exprs()) { + if (auto inner_loop = dynamic_cast(expr)) { + if (!isMmaInitLoop(inner_loop->body())) { + return false; + } + } else if (auto uop = dynamic_cast(expr)) { + if (!ir_utils::isTvOp(expr) || + uop->getUnaryOpType() != UnaryOpType::Set) { + return false; + } + if (auto ti = dynamic_cast(expr->output(0))) { + if (!ti->view()->definition() || + !ti->view()->definition()->isA()) { + return false; + } + } + if (auto tv = dynamic_cast(expr->output(0))) { + if (!tv->definition() || !tv->definition()->isA()) { + return false; + } + } + } else if (auto ite = dynamic_cast(expr)) { + if (!isMmaInitLoop(ite->thenBody())) { + return false; + } + if (!isMmaInitLoop(ite->elseBody())) { + return false; + } + } else { + return false; + } + } + return true; +} + +} // namespace + +bool isMmaInitLoop(const kir::ForLoop* loop) { + return isMmaInitLoop(loop->body()); +} + +} // namespace mma_util + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace 
torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h new file mode 100644 index 000000000000..f98c63961ccd --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h @@ -0,0 +1,144 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace mma_util { + +//! [WarpMmaSwizzler]: +//! This class is used to implement the thread swizzle format +//! required for the mma macros, cf. PTX ISA 9.7.13.4. +//! +//! The mma instructions (Volta through Ampere) require specific +//! thread mapping within a warp for both the mma inputs and +//! mma outputs. All mma swizzle patterns seen so far turned out +//! to be affine, so we could use the normal scheduler interface +//! to fulfill the mma thread swizzle pattern. And fusion with +//! other non-mma ops and validations can just natually rely on the current +//! iterdomain infrastructure. +//! +//! This is different from a normal scheduler utility though, +//! as the thread mapping within a warp are *required* to be +//! a specific pattern which currently translates to an enforced +//! requirement that all the leaf domains produced by WarpMmaSwizzler +//! cannot be further transformed (split/merge/reorder etc.). +//! +//! Currently WarpMmaSwizzler can be accessed by schedulers through +//! TensorView::applyMmaSwizzle, and the current scheduling procedure is +//! as follows: +//! +//! Step 1. Before scheduling, the mma op needs to be configured with a macro +//! type, +//! either manually or inferred (eg. Volta_16_16_4). +//! +//! Step 2. Scheduler can tile the outer dimensions based on any heuristics, +//! i.e. +//! the CTA tiling, warp tiling, splitK etc. +//! +//! Step 3. The scheduler will need to split the innermost part of the 3 +//! involved +//! root dimensions, they need to be ordered as M,N,K on the rightmost of +//! tensordomain (see [Operand Layout Convention] for exact definition). +//! +//! For example before calling WarpMmaSwizzler, the domain could look like: +//! [TileM, TileN, TileK, Im(16), In(16), Rk(4)], to use Volta_16_16_4. +//! The rightmost 3 iterdomains need to be the innermost component of their +//! corresponding root id, similar to vectorization except this requirement +//! applies to all 3 rightmost dims. +//! +//! Before applying swizzle, WarpMmaSwizzler will try to validate: +//! 1. The "innermost-ness" of the rightmost 3 iterdomains. E.g: +//! Xo, Xi = split(X, 16), +//! Xo doesn't check, Xi would check. +//! 2. The rightmost three are constant sized, and they are ordered as +//! M,N,K. +//! In the case of operand schedule before the broadcast, only 2 of +//! the axis are see, and they still need to follow the same order, +//! i.e. need to be M,K or N,K. +//! 3. The rightmost three axes have matching size with the selected +//! mma macro. +//! +//! Step 4. WarpMmaSwizzler will transform the rightmost 3 domains to the +//! correct swizzle +//! format and will parallelize the TIDx, which is reserved for lane id. The +//! transformed inner iterdomains will be locked with WarpMapped tag so that +//! they cannot be further transformed. Currently the only change that +//! scheduler can still do after this step is to vectorize the innermost +//! iterdomain. +//! +//! Notes: +//! This version of implementation is trying to balance the composition +//! flexibility and validation complexity. Currently the validation protocol +//! 
is that if the rightmost 3 dimensions given to WarpMmaSwizzler are indeed +//! innermost components of the 3 root id's and their dimensions match the mma +//! macro, the swizzle format produced by WarpMmaSwizzler will be correct for +//! the macro and we just lock the innermost iterdomains from further +//! transformations. +//! +//! Ninja users/schedulers might go for 2 cases that we currently don't +//! support: +//! +//! 1. Equivalent affine transforms: +//! Even though the mma swizzles are affine, there are still infinitely many +//! equivalent ways to implement +//! the same affine transform. E.g. io,ii = split(i,8); ioii = +//! merge(io,ii); would make ioii equiv to i if it's a divisible split. One +//! can use this to construct infinite many equivalent affine swizzles. +//! +//! Users/schedulers might want to have a different but equivalent affine +//! representation from the one provided +//! by WarpMmaSwizzler, but validating them needs some extra work +//! canonicalizing the affine transforms. So short term wouldn't support +//! this flexibility. +//! +//! 2. Swizzled data input: +//! It is also possible that the data input has other swizzles before +//! entering the fusion already and some +//! might be natively compatible with mma format. This is a very broad +//! category of use cases and we'd have to consider enabling any use like +//! this case-by-case. +class TORCH_CUDA_CU_API WarpMmaSwizzler { + public: + //! Applies the output mma swizzling to the given tv, should be used + //! on mma output or tv's involved in epilog fusion, i.e. bias. + //! The rightmost iterdomains must follow the m,n,k convention before calling. + static void scheduleMmaWarpOutput(TensorView* tv, MmaOptions options); + + //! Applies the input mma swizzling to the given tv, should be used + //! on mma input or tv's involved in any fusion before mma, but after smem + //! read. + //! The rightmost iterdomains must follow the m,n,k convention before calling. + static void scheduleOperandRead( + TensorView* tv, + MmaOptions options = MmaOptions()); + + private: + //! Swizzle implementations for Volta mma. + static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options); + static void scheduleVoltaM16N16K4Fp32Output( + TensorView* tv, + const MmaOptions& options); + + //! Utility to lock the transformed dimensions from further transforms. + static void setWarpMapped(TensorView* tv, int number_of_dims); +}; + +void checkDimSize( + TensorView* tv, + std::vector axis, + std::vector expect); + +// Returns if the loopnest is initializing for an mma op. 
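+// (True only when every expression in the nest is a Set writing into an MmaOp
+// accumulator, possibly wrapped in nested loops or predicates.)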
+bool isMmaInitLoop(const kir::ForLoop* loop); + +} // namespace mma_util + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 714b712c2303..6be84459fa85 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace torch { @@ -1393,6 +1394,97 @@ std::vector getBroadcastMultiples(TensorView* reference_tv) { return multiples; } +namespace matmul_utils { + +void scheduleWarpTileWithReduction(TensorView* tv, GemmTileOptions tile) { + // Assumes + // [M, N, K] + auto cta_tile = tile.cta_tile; + auto warp_tile = tile.warp_tile; + auto instruction_tile = tile.instruction_tile; + + TORCH_CHECK( + warp_tile.k == cta_tile.k, + "schedule warp tile: currently no support for splitting k dimension to different warps"); + + mma_util::checkDimSize( + tv, {-3, -2, -1}, {cta_tile.m, cta_tile.n, cta_tile.k}); + + // -3 -2 -1 + //[... M, N, K] + + // Distribute warp tile: + tv->split(-3, warp_tile.m); + tv->split(-2, warp_tile.n); + + // -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K] + tv->split(-4, instruction_tile.m); + tv->split(-2, instruction_tile.n); + tv->split(-1, instruction_tile.k); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Ko Ki] + + tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Mi Ni Ki] +} + +void scheduleWarpTileWithNoReduction(TensorView* tv, GemmTileOptions tile) { + // Assumes + // [M, N, K] + auto cta_tile = tile.cta_tile; + auto warp_tile = tile.warp_tile; + auto instruction_tile = tile.instruction_tile; + + mma_util::checkDimSize(tv, {-2, -1}, {cta_tile.m, cta_tile.n}); + + // -2 -1 + //[... M, N] + + // Distribute warp tile: + tv->split(-2, warp_tile.m); + tv->split(-1, warp_tile.n); + + // -4 -3 -2 -1 + // [Mwo Mw Nwo Nw ] + tv->split(-3, instruction_tile.m); + tv->split(-1, instruction_tile.n); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni] + + tv->reorder({{-5, -4}, {-4, -2}, {-3, -5}, {-2, -3}}); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Mw Nw Mi Ni] +} + +//! Split the innermost dim to a vectorized load +void scheduleContiguousVectorLoad( + TensorView* tv, + GemmTileOptions tile, + int vector_word) { + auto warp_dims = tile.cta_tile / tile.warp_tile; + int num_of_thread = warp_dims.m * warp_dims.n * warp_dims.k * 32; + + tv->split(-1, num_of_thread * vector_word); + tv->split(-1, vector_word); + // [..., thread, vec] + // distribute to warp: + tv->split(-2, 32); + tv->split(-3, warp_dims.n * warp_dims.k); + + tv->axis(-1)->parallelize(ParallelType::Vectorize); + tv->axis(-2)->parallelize(ParallelType::TIDx); + tv->axis(-3)->parallelize(ParallelType::TIDy); + tv->axis(-4)->parallelize(ParallelType::TIDz); +} + +} // namespace matmul_utils + } // namespace scheduler_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 48686e09d959..2460d09e49a6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -248,6 +248,23 @@ struct BroadcastMultiple { // data type size. 
std::vector getBroadcastMultiples(TensorView* reference_tv); +namespace matmul_utils { + +TORCH_CUDA_CU_API void scheduleContiguousVectorLoad( + TensorView* tv, + GemmTileOptions tile, + int vector_word); + +TORCH_CUDA_CU_API void scheduleWarpTileWithReduction( + TensorView* tv, + GemmTileOptions tile); + +TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction( + TensorView* tv, + GemmTileOptions tile); + +} // namespace matmul_utils + } // namespace scheduler_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 911bda3da04b..36302b89f153 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // Cleanup #include @@ -933,6 +934,21 @@ bool TensorView::isEmptyTensor() const { }); } +void TensorView::applyMmaSwizzle(MmaOptions options) { + switch (options.operand) { + case MmaOptions::Operand::NotOperand: + mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options); + break; + case MmaOptions::Operand::A: + case MmaOptions::Operand::B: + mma_util::WarpMmaSwizzler::scheduleOperandRead(this, options); + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown operand flag"); + break; + } +} + TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) { TORCH_CHECK(shape_.empty() || shape_.size() == ndims); TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index da7d9443a70e..6b5ef0f5ed91 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -187,6 +187,8 @@ static const char* expr_type2string(ExprType t) { return "BroadcastOp"; case ExprType::WelfordOp: return "WelfordOp"; + case ExprType::MmaOp: + return "MmaOp"; case ExprType::TransposeOp: return "TransposeOp"; case ExprType::ShiftOp: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index dbd756424f62..8c47e832ba4f 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -88,6 +88,7 @@ enum class ExprType { ReductionOp, BroadcastOp, WelfordOp, + MmaOp, TransposeOp, ShiftOp, GatherOp, From edd43d9a04e8c9a58ef6d5834f14f17f170aedd2 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 21 Feb 2022 16:23:25 -0800 Subject: [PATCH 02/57] mma parallel type && cleanup --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 9 +++++- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 28 ++++++++----------- .../jit/codegen/cuda/lower_validation.cpp | 15 ++++++++++ .../jit/codegen/cuda/scheduler/mma_utils.cpp | 4 ++- torch/csrc/jit/codegen/cuda/type.cpp | 2 ++ torch/csrc/jit/codegen/cuda/type.h | 1 + 6 files changed, 41 insertions(+), 18 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 5e1fc232d986..f0647db94d19 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1071,6 +1071,11 @@ indexMapFromTV( // initial index to zero for unswitch. 
std::unordered_set zero_loops; + bool within_mma_loops = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->iter_domain()->isInstruction(); + }); + for (auto loop : loops) { Val* idx = nullptr; const auto same_parallel_type = @@ -1103,7 +1108,9 @@ indexMapFromTV( // only done when there's a matching domain with the same // parallel type (loop->iter_domain()->isThread() && is_local && - (same_parallel_type || loop->iter_domain()->isInstruction()))) { + (same_parallel_type || + (within_mma_loops && + loop->iter_domain()->getParallelType() == ParallelType::TIDx)))) { idx = GpuLower::current()->kernel()->zeroVal(); zero_loops.insert(loop); } else { diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 90ca0a8a6e5f..57f989d73377 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -309,11 +309,11 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } private: - Val* const out_; - Val* const in_a_; - Val* const in_b_; - Val* const init_; - c10::optional options_; + Val* const out_ = nullptr; + Val* const in_a_ = nullptr; + Val* const in_b_ = nullptr; + Val* const init_ = nullptr; + c10::optional options_ = c10::nullopt; }; class TORCH_CUDA_CU_API TransposeOp : public Expr { @@ -716,15 +716,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return definition() == nullptr; } - bool isInstruction() const { - return is_instruction_; - } - - bool isMmaSwizzled() const { - return is_mma_swizzled_; - } - - //! Used by WarpMmaSwizzler, marks that this id represents a + //! Marks that this id represents a //! instruction loop, mma use only. //! //! An instruction loop can be considered a generalization of @@ -744,8 +736,12 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! transformed version of above to match the mma swizzle. //! So it's different implicit loopnest for different macros. //! WarpMmaSwizzler will label the instruction loops case-by-case. - void toInstruction() { - is_instruction_ = true; + bool isInstruction() const { + return parallel_type_ == ParallelType::Mma; + } + + bool isMmaSwizzled() const { + return is_mma_swizzled_; } //! 
Used by WarpMmaSwizzler, this is an utility for WarpMmaSwizzler diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 8f4007bba0bb..867b8eb063bb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -557,6 +557,21 @@ void validateParallelize(Fusion* fusion) { if (producer->isFusionInput()) { continue; } + + if (is_mma) { + // TODO: lift this check in a follow up + TORCH_INTERNAL_ASSERT( + std::all_of( + producer->domain()->domain().begin() + + producer->getComputeAtPosition(), + producer->domain()->domain().end(), + [](IterDomain* id) { + return id->isMmaSwizzled() || + (id->isBroadcast() && + id->getParallelType() == ParallelType::Serial); + }), + "Temporary check: all id's on right of mma producer CA must be mma swizzled"); + } const auto parallel_bcast_doms = pred_map.getParallelBroadcastDomains(producer); for (const auto i : c10::irange(producer->nDims())) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 8d3f74852edd..593b4aa98adc 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -362,7 +362,9 @@ void WarpMmaSwizzler::scheduleVoltaM16N16K4Fp32Output( if (is_reduction && tv->definition()->isA()) { // Set instruction loops for mma reduce output for (int pos : c10::irange(5)) { - tv->axis(-pos - 1)->toInstruction(); + if (!tv->axis(-pos - 1)->isThread()) { + tv->axis(-pos - 1)->parallelize(ParallelType::Mma); + } tv->axis(-pos - 1)->toMmaSwizzled(); } } diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 6b5ef0f5ed91..e70370e858e0 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -515,6 +515,8 @@ static const char* parallel_type2string(ParallelType t) { return "UR"; case ParallelType::Unswitch: return "US"; + case ParallelType::Mma: + return "MMA"; case ParallelType::Serial: return "S"; default: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 8c47e832ba4f..815ce28a8813 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -215,6 +215,7 @@ enum class ParallelType { MisalignedVectorize, Unroll, Unswitch, + Mma, Serial }; From ddac459e59a5d828edee496ee370b1bac6073334 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 22 Feb 2022 00:51:51 -0800 Subject: [PATCH 03/57] cleanup --- torch/csrc/jit/codegen/cuda/codegen.cpp | 12 ++------- torch/csrc/jit/codegen/cuda/index_compute.cpp | 4 +-- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 4 --- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 2 -- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 27 ------------------- torch/csrc/jit/codegen/cuda/lower_utils.h | 2 -- 6 files changed, 4 insertions(+), 47 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index a71a0591a475..93a5312913b1 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1368,17 +1368,9 @@ class CudaKernelGenerator : private OptOutConstDispatch { case MemoryType::Shared: if (kir::ExpressionEvaluator::isConst(size)) { // Static shared memory + indent() << "__shared__ " << buffer_dtype << " " << varName(tv) + << "[" << genInline(size) << "];\n"; indent(); - - // align to 16B if any access is vectorized - // TODO: - // This is a WAR to support 
vectorized access, - // eventually want to always use dynamic smem alloc. - if (hasVectorizedAccess(alloc)) { - code_ << "__align__(16) "; - } - code_ << "__shared__ " << buffer_dtype << " " << varName(tv) << "[" - << genInline(size) << "];\n"; } else { // Align Offset Position indent() << "offset = alignBufferSize(offset," diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index f0647db94d19..6f579d7ee09c 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -2083,8 +2083,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops); std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( - consumer_tv, loops, alloc_info.init_for_loop, true, nullptr); + std::tie(loop_to_ind_map, zero_loops) = + indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true); ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 57f989d73377..b9ef280b63ca 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -764,7 +764,6 @@ class TORCH_CUDA_CU_API IterDomain : public Val { friend TensorDomain; friend ReplayTransformations; friend IndexReferenceReplay; - friend mma_util::WarpMmaSwizzler; private: //! Valid range is defined as [start:-stop_offset] @@ -782,9 +781,6 @@ class TORCH_CUDA_CU_API IterDomain : public Val { // definitions of split/merge. bool is_simple_ = true; - // Tracks if this id is implicit in an instruction, i.e mma - bool is_instruction_ = false; - //! Tracks if this id represents a thread swizzled loop or //! models an implicit loop within instructions. Should not make //! any changes once an id is warp mapped. 
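Taken together, the pieces introduced across these patches (MmaBuilder, fusedMultiplySum, TensorView::applyMmaSwizzle, and the warp-mapped / ParallelType::Mma annotations) are intended to compose roughly as in the sketch below. This is an illustrative fragment only: the function and tensor names are placeholders, it assumes the torch::jit::fuser::cuda namespaces are in scope with an active Fusion whose half-precision operands have already been broadcast to a common 3-D [M, K, N] shape for a TT layout, and attaching the options through MmaOp::configureOptions is one plausible wiring rather than the exact sequence used by the tests.

// Sketch: schedule a single Volta 16x16x4 mma tile (assumptions as above).
void scheduleSingleTileMma(TensorView* a_bcast, TensorView* b_bcast) {
  // One instruction tile; larger problems would set up CTA/warp tiling first
  // (e.g. with scheduler_utils::matmul_utils::scheduleWarpTileWithReduction).
  GemmTileOptions gemm_tile(
      GemmTile(16, 16, 4), GemmTile(16, 16, 4), GemmTile(16, 16, 4));
  MmaBuilder mma_builder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile);
  mma_builder.layout(MmaOptions::MmaInputLayout::TT);

  // fusedMultiplySum creates the MmaOp; axis 1 is the K dimension here.
  TensorView* acc = fusedMultiplySum(a_bcast, b_bcast, {1});
  acc->definition()->as<MmaOp>()->configureOptions(mma_builder.build());

  // WarpMmaSwizzler expects the rightmost ids ordered M, N, K on each tv.
  for (auto tv : {a_bcast, b_bcast, acc}) {
    tv->reorder({{-2, -1}, {-1, -2}});
  }

  // Swizzle operands and accumulator: TIDx becomes the lane id and the
  // resulting warp-mapped ids are locked against further transforms.
  a_bcast->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  b_bcast->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  acc->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
}

The swizzled leaf ids then carry the warp-mapped and ParallelType::Mma annotations that the validation, indexing, and predicate-elimination changes above key off of.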
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 66cf2397c405..33fb9f9b71b6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -849,7 +849,6 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) is_rfactor_domain_(src->is_rfactor_domain_), is_padded_dimension_(src->is_padded_dimension_), padded_to_size_(src->padded_to_size_), - is_instruction_(src->is_instruction_), is_mma_swizzled_(src->is_mma_swizzled_) {} bool IterDomain::sameAs(const Statement* other) const { @@ -869,7 +868,6 @@ bool IterDomain::sameAs(const Statement* other) const { is_same = is_same && ScalarCheck::sameAs(start(), other_id->start()); is_same = is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset()); - is_same = is_same && (is_instruction_ == other_id->is_instruction_); return is_same; } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 9a13597f64d1..d878936196f8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -481,33 +481,6 @@ std::vector replaceInputsInExpr( return ReplaceExprInput::replace(exprs, replacement_map); } -bool hasVectorizedAccess(const kir::Allocate* alloc) { - auto buffer_tv = dynamic_cast(alloc->buffer()); - TORCH_INTERNAL_ASSERT( - buffer_tv != nullptr, "checking vectorization on non-tv allocation"); - - for (auto id : buffer_tv->domain()->domain()) { - if (id->getParallelType() == ParallelType::Vectorize) { - return true; - } - } - auto uses = buffer_tv->fusion()->unordered_uses(buffer_tv); - - for (auto use : uses) { - for (auto out : use->outputs()) { - if (auto out_tv = dynamic_cast(out)) { - for (auto id : out_tv->domain()->domain()) { - if (id->getParallelType() == ParallelType::Vectorize) { - return true; - } - } - } - } - } - - return false; -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index a7981e724539..4ed6c25e731a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -137,8 +137,6 @@ std::vector replaceInputsInExpr( const std::vector& exprs, const std::unordered_map& replacement_map); -bool hasVectorizedAccess(const kir::Allocate* alloc); - } // namespace cuda } // namespace fuser } // namespace jit From 2f08d09c7381ee3115910850862f497c547f4924 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 22 Feb 2022 01:08:42 -0800 Subject: [PATCH 04/57] alignment --- torch/csrc/jit/codegen/cuda/codegen.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 93a5312913b1..e02b2380e789 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1368,8 +1368,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { case MemoryType::Shared: if (kir::ExpressionEvaluator::isConst(size)) { // Static shared memory - indent() << "__shared__ " << buffer_dtype << " " << varName(tv) - << "[" << genInline(size) << "];\n"; + auto va = kernel_->summary().vectorized_accesses; + if (va.find(tv) != va.end()) { + indent() << " __align__(16) "; + } else { + indent(); + } + code_ << "__shared__ " << buffer_dtype << " " << varName(tv) << "[" + << genInline(size) << "];\n"; indent(); } else { // Align Offset Position From 
ca77ff4731ea657d04f6b8caea0f2d6c857ff55f Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 22 Feb 2022 01:11:08 -0800 Subject: [PATCH 05/57] comment --- torch/csrc/jit/codegen/cuda/arith.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 3451bd0c12b5..1e5b4cd30aa3 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -568,7 +568,7 @@ TORCH_CUDA_CU_API TensorView* gather( //! return sum(c, axes) //! //! \param tv_a first multiply operand -//! \param tv_b first multiply operand +//! \param tv_b second multiply operand //! \param axes axes to sum over //! \param init sum initial value //! From 9caeb18e5f19e8ec8c949c1f280393fc4e81847f Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 12:16:50 -0700 Subject: [PATCH 06/57] change request --- test/cpp/jit/test_gpu.cpp | 2 +- test/cpp/jit/test_gpu_tensorcore.cpp | 92 +++++++++++-------- torch/csrc/jit/codegen/cuda/arith.cpp | 3 +- torch/csrc/jit/codegen/cuda/arith.h | 2 +- torch/csrc/jit/codegen/cuda/codegen.cpp | 11 ++- torch/csrc/jit/codegen/cuda/index_compute.cpp | 23 ++--- .../jit/codegen/cuda/ir_interface_nodes.h | 7 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 16 ++-- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 11 +++ torch/csrc/jit/codegen/cuda/ir_utils.cpp | 6 +- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 2 +- torch/csrc/jit/codegen/cuda/lower_index.cpp | 4 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 4 +- .../jit/codegen/cuda/lower_validation.cpp | 7 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 4 +- .../jit/codegen/cuda/scheduler/mma_utils.cpp | 9 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 7 ++ 17 files changed, 128 insertions(+), 82 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index b7125ade5afa..b0e3c4722e0c 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -20630,7 +20630,7 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) { } #endif -TEST_F(NVFuserTest, FusionIssue1430) { +TEST_F(NVFuserTest, FusionIssue1430_CUDA) { // Derived from an expression sorting issue when using loop map, now expr // sorting uses parallel map. std::unique_ptr fusion_ptr = std::make_unique(); diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index 7f0e0ac40dd9..e05d06c8da90 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -131,7 +131,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { // Leaving both sets of mma inputs for volta outside // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b,{1}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); fusion.addOutput(tv2); @@ -181,20 +181,22 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { // Schedule the output instruction tile. // Assumes last 3 dims are mnk - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); // Set memory type. 
tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); - + at::manual_seed(0); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 4}, options); auto t1 = at::randn({4, 16}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -222,7 +224,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { // Leaving both sets of mma inputs for volta outside // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b,{2}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); fusion.addOutput(tv2); @@ -246,9 +248,11 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -258,7 +262,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { auto t1 = at::randn({16, 4}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); @@ -283,7 +287,7 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { // Leaving both sets of mma inputs for volta outside // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b,{0}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); fusion.addOutput(tv2); @@ -306,14 +310,16 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { // To MNK tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}}); tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - + // To MNK tv1cr->reorder({{0, 2}, {1, 0}, {2, 1}}); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - + tv2c->reorder({{0, 2}, {1, 0}, {2, 1}}); - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -323,7 +329,7 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { auto t1 = at::randn({4, 16}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); @@ -353,7 +359,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { auto tv0b = broadcast(tv0, {false, false, true}); auto tv1b = broadcast(tv1, {true, false, false}); - auto tv2 = fusedMultiplySum(tv0b, tv1b,{1}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); fusion.addOutput(tv2); @@ -538,12 +544,14 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { // --------------------------------------------------------------------------- // Use WarpMmaSwizzler for the innermost instruction tile (Mi,Ni, Ki) on // output - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + // -6 -5 -4 -3 -2 -1 // [Mwo Nwo Mw Nw Mi Ni] - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + // Inline broadcast with smem write. 
tv0b->computeAt(tv0cw, -2); tv1b->computeAt(tv1cw, -2); @@ -572,7 +580,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -652,7 +660,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { // Make warp tile: // ------------------------------------------------------------------------- scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); // -8 -7 -6 -5 -4 -3 -2 -1 // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] tv0cr->computeAt(tv2c, -4); @@ -670,7 +679,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { tv0r->merge(-2); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( tv0cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] @@ -680,17 +690,20 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { // [Mo,No,Ko,i,wy,wx,v] scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( tv1cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - + // Schedule mma output // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); tv0b->computeAt(tv0cw, -2); tv1b->computeAt(tv1cw, -2); @@ -717,7 +730,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); @@ -745,7 +758,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { // Leaving both sets of mma inputs for volta outside // currently since they need to be swizzled. 
- auto tv2 = fusedMultiplySum(tv0b, tv1b,{0}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); fusion.addOutput(tv2); @@ -796,7 +809,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { // Make warp tile: // ------------------------------------------------------------------------- scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(tv2, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); // -8 -7 -6 -5 -4 -3 -2 -1 // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] tv0cr->computeAt(tv2c, -4); @@ -815,7 +829,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { tv0r->merge(-2); scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( tv0cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv0r, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); tv0cw->setMemoryType(MemoryType::Shared); // [Mo,Ko,i,wy,wx,v] @@ -828,7 +843,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { // [Mo,No,Ko,i,wy,wx,v] scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( tv1cw, gemm_tile, 8); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(tv1r, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); tv1cw->setMemoryType(MemoryType::Shared); // Schedule mma input // --------------------------------------------------------------------------- @@ -838,9 +854,11 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv0b->computeAt(tv0cw, -2); tv1b->computeAt(tv1cw, -2); @@ -866,7 +884,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion,{t0, t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).t().matmul(t1.to(at::kFloat)); diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 86fe7720e4b1..2ca707c82f72 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -1585,7 +1585,8 @@ TensorView* fusedMultiplySum( TORCH_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor"); // TODO: - // Add tf32 + // Add tf32 and other mma data types + // Add fallback path for non-mma data types. TORCH_CHECK(tv_a->getDataType().value() == DataType::Half); TORCH_CHECK(tv_b->getDataType().value() == DataType::Half); diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 1e5b4cd30aa3..dbac1f22b0d7 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -574,7 +574,7 @@ TORCH_CUDA_CU_API TensorView* gather( //! //! Note & TODO: //! currently only support lowering to a mma op -//! through this interface. +//! through this interface and only support fp16 inputs. //! will support converting back to multiply and reduce in //! a follow up. 
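//!
//! Example (a sketch only, mirroring the tensorcore tests in this patch
//! series; tensor names are illustrative and not part of this diff):
//!   // tv0: [M,K] fp16, tv1: [K,N] fp16
//!   auto tv0b = broadcast(tv0, {false, false, true});  // -> [M,K,N]
//!   auto tv1b = broadcast(tv1, {true, false, false});  // -> [M,K,N]
//!   auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});      // reduce over K
//!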
TORCH_CUDA_CU_API TensorView* fusedMultiplySum( diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index e02b2380e789..9a41fea15b70 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -361,8 +361,15 @@ class CudaKernelGenerator : private OptOutConstDispatch { size_t vector_word_size = 1; if (uop->out()->isA()) { - if (auto mma = dynamic_cast( - uop->out()->as()->view()->definition())) { + auto out_tv = uop->out()->as()->view(); + if (std::any_of( + out_tv->domain()->domain().begin(), + out_tv->domain()->domain().end(), + [&](IterDomain* id) { return id->isMma(); })) { + auto mma = dynamic_cast( + uop->out()->as()->view()->definition()); + TORCH_INTERNAL_ASSERT( + mma != nullptr, "CodeGen: mma op not in mma loop"); genMmaInitialization(mma, uop); return; } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 6f579d7ee09c..2c75dc60e3ed 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1038,6 +1038,13 @@ indexMapFromTV( std::unordered_map loop_to_ind_map; + // Check if the current op has an implicit loop implemented + // within an mma instruction. + bool within_mma_loops = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->iter_domain()->isMma(); + }); + // When indexed as a producer, the parallel types of the the // producer domains may not be the same as those of the loops, but // that's still valid parallelization. However, in that case, using @@ -1045,7 +1052,8 @@ indexMapFromTV( // with zero isn't valid. That's only valid when there's a matching // IterDomain in the producer tensor that has the same parallel // type. - auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool { + auto find_matching_parallel_domain = + [tv, within_mma_loops](IterDomain* id) -> bool { const auto gpu_lower = GpuLower::current(); auto it = std::find_if( tv->domain()->domain().begin(), @@ -1055,7 +1063,8 @@ indexMapFromTV( // validateParallelize as well. return gpu_lower->caIndexMap().areMapped(id, tv_id) || (gpu_lower->caLoopMap().areMapped(id, tv_id) && - ir_utils::derivedFromRootCAAxes(tv, tv_id)); + ir_utils::derivedFromRootCAAxes(tv, tv_id)) || + (id->getParallelType() == ParallelType::TIDx && within_mma_loops); }); if (it == tv->domain()->domain().end()) { return false; @@ -1071,11 +1080,6 @@ indexMapFromTV( // initial index to zero for unswitch. 
std::unordered_set zero_loops; - bool within_mma_loops = - std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { - return fl->iter_domain()->isInstruction(); - }); - for (auto loop : loops) { Val* idx = nullptr; const auto same_parallel_type = @@ -1107,10 +1111,7 @@ indexMapFromTV( // Similarly for local memory tensors, zero replacement can be // only done when there's a matching domain with the same // parallel type - (loop->iter_domain()->isThread() && is_local && - (same_parallel_type || - (within_mma_loops && - loop->iter_domain()->getParallelType() == ParallelType::TIDx)))) { + (loop->iter_domain()->isThread() && is_local && same_parallel_type)) { idx = GpuLower::current()->kernel()->zeroVal(); zero_loops.insert(loop); } else { diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index e796512263a8..4cad54f238a9 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -442,12 +442,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! TODO: This step will very likely be removed in a follow up PR. All of //! the options configured here could actually be inferred from fusion IR //! once we are feature complete. - void configureMma(MmaOptions options) { - TORCH_CHECK(definition(), "configureMma: invalid for input tensor ", this); - auto mma = dynamic_cast(definition()); - TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", this); - mma->configureOptions(options); - } + void configureMma(MmaOptions options); //! Transforms the innermost iterdomains according to the given mma swizzle, //! this should be used on the tvs that are either inputs/outputs of an diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index b9ef280b63ca..f7fa1d92bd6d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -25,12 +25,6 @@ class ViewTransform; class Scope; class IrCloner; -namespace mma_util { - -class WarpMmaSwizzler; - -} // namespace mma_util - //! Returns true if both v1 and v2 are scalars, are the same type of scalars, //! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both //! vals are `Int` will dispatch to v1->as()->sameAs(v2.as()) @@ -274,6 +268,14 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { public: MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); + MmaOp( + IrBuilderPasskey, + Val* out, + Val* in_a, + Val* in_b, + Val* init, + MmaOptions options); + MmaOp(const MmaOp* src, IrCloner* ir_cloner); Val* out() const { @@ -736,7 +738,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! transformed version of above to match the mma swizzle. //! So it's different implicit loopnest for different macros. //! WarpMmaSwizzler will label the instruction loops case-by-case. 
- bool isInstruction() const { + bool isMma() const { return parallel_type_ == ParallelType::Mma; } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 33fb9f9b71b6..b262ebc1bdf1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -556,6 +556,17 @@ MmaOp::MmaOp( addInput(in_b); } +MmaOp::MmaOp( + IrBuilderPasskey passkey, + Val* out, + Val* in_a, + Val* in_b, + Val* init, + MmaOptions options) + : MmaOp(passkey, out, in_a, in_b, init) { + options_ = options; +} + MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), out_(ir_cloner->clone(src->out_)), diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index b43c59207531..ecd9659f6760 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -329,10 +329,8 @@ struct SubstituteInExpr : public OptInDispatch { auto init = reference_->sameAs(mma_expr->init()) ? substitute_->as() : mma_expr->init(); - auto options = mma_expr->options(); - expr_ = - IrBuilder::create(mma_expr->container(), out, in_a, in_b, init); - expr_->as()->configureOptions(options); + expr_ = IrBuilder::create( + mma_expr->container(), out, in_a, in_b, init, mma_expr->options()); } private: diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 03b7087ab64d..51b26ef141de 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -302,7 +302,7 @@ Val* ForLoop::step() const { bool ForLoop::isTrivial() const { // These loops are not materialized if (vectorize() || iter_domain()->isBroadcast() || - iter_domain()->isStride() || iter_domain()->isInstruction()) { + iter_domain()->isStride() || iter_domain()->isMma()) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 717f1e9f0e25..c8ccb4f96ff1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -375,8 +375,8 @@ void IndexLowering::handle(const MmaOp* mma) { const auto a = lowerSrcIndex(mma->inA(), mma->out()); const auto b = lowerSrcIndex(mma->inB(), mma->out()); const auto out = lowerDstIndex(mma->out()); - auto mma_indexed = IrBuilder::create(out, a, b, mma->init()); - mma_indexed->configureOptions(mma->options()); + auto mma_indexed = + IrBuilder::create(out, a, b, mma->init(), mma->options()); pushBack(mma_indexed); } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index d878936196f8..29c8f7d66796 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -463,8 +463,8 @@ class ReplaceExprInput : private kir::ExprMutator { node->out(), replaced_inputs.value().at(node->inA()), replaced_inputs.value().at(node->inB()), - node->init()); - replacement->configureOptions(node->as()->options()); + node->init(), + node->options()); registerReplaceWithPredicate(node, replacement); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 867b8eb063bb..101ba78e45fe 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -474,11 +474,10 @@ void validateParallelizationOfTensor( tv); // Check that TIDx is exact lane_id const auto& paralel_dim_map = 
GpuLower::current()->parallelDimensionMap(); + TORCH_INTERNAL_ASSERT( - paralel_dim_map.isExact(ptype), - "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); - TORCH_INTERNAL_ASSERT( - paralel_dim_map.get(ptype)->getInt().has_value() && + paralel_dim_map.isExact(ptype) && + paralel_dim_map.get(ptype)->getInt().has_value() && paralel_dim_map.get(ptype)->getInt().value() == at::cuda::warp_size(), "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 0f240c3ad6d4..5958c415d578 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -249,8 +249,8 @@ void OptOutMutator::mutate(MmaOp* mma) { auto container = mma->container(); auto options = mma->options(); container->removeExpr(mma); - auto new_mma = IrBuilder::create(container, out, in_a, in_b, init); - new_mma->configureOptions(options); + auto new_mma = + IrBuilder::create(container, out, in_a, in_b, init, options); } void OptOutMutator::mutate(BroadcastOp* bop) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 593b4aa98adc..770d1d7949c5 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -139,8 +139,15 @@ void checkDimSize( TensorView* tv, std::vector axis, std::vector expect) { + TORCH_INTERNAL_ASSERT( + axis.size() == expect.size(), + "CheckDimSize: Mismatched axis and expect size"); ExpressionEvaluator const_eval(tv->fusion()); for (auto axis_index : c10::irange(axis.size())) { + TORCH_INTERNAL_ASSERT( + ((axis[axis_index]+tv->nDims()) >= 0) && + (axis[axis_index] < (int) tv->nDims()), + "CheckDimSize: axis position out of bound ", axis[axis_index]," ", tv->nDims()); auto id = tv->axis(axis[axis_index]); auto maybe_extent = const_eval.evaluate(id->extent()); TORCH_CHECK( @@ -220,7 +227,7 @@ void validateResultInnerMN(TensorView* tv, int m, int n) { TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( tv->getMaybeRFactorDomain()[root_dim - 2], tv->axis(-2), m)); TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), m)); + tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), n)); } void scheduleVoltaA(TensorView* tv, MmaOptions options) { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 36302b89f153..35143c824868 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -1013,6 +1013,13 @@ TensorView* TensorViewBuilder::build() const { IrBuilder::create(domain, contiguity_), dtype_); } +void TensorView::configureMma(MmaOptions options) { + TORCH_CHECK(definition(), "configureMma: invalid for input tensor ", this); + auto mma = dynamic_cast(definition()); + TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", this); + mma->configureOptions(options); +} + } // namespace cuda } // namespace fuser } // namespace jit From de1d3ecb9c0f3a573e2dcaad857810ade71235ba Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 15:03:20 -0700 Subject: [PATCH 07/57] fix same parallel type --- .../cpp/nvfuser/layer_norm_backward.cpp | 3 +- benchmarks/cpp/nvfuser/rms_norm.cpp | 14 +++++---- benchmarks/cpp/nvfuser/rms_norm_backward.cpp | 29 +++++++++---------- torch/csrc/jit/codegen/cuda/index_compute.cpp | 12 ++++---- 
.../jit/codegen/cuda/scheduler/mma_utils.cpp | 9 ++++-- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp index 5bf6d8c0f993..fe95c01048f2 100644 --- a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp @@ -64,8 +64,7 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) { if (dtype != DataType::Float) { layer_norm_results.grad_input = castOp(dtype, layer_norm_results.grad_input); - layer_norm_results.grad_bias = - castOp(dtype, layer_norm_results.grad_bias); + layer_norm_results.grad_bias = castOp(dtype, layer_norm_results.grad_bias); layer_norm_results.grad_weight = castOp(dtype, layer_norm_results.grad_weight); } diff --git a/benchmarks/cpp/nvfuser/rms_norm.cpp b/benchmarks/cpp/nvfuser/rms_norm.cpp index fd93dcc518a6..9c46896366cc 100644 --- a/benchmarks/cpp/nvfuser/rms_norm.cpp +++ b/benchmarks/cpp/nvfuser/rms_norm.cpp @@ -18,7 +18,9 @@ using namespace torch::jit::fuser::cuda; //------------------------------------------------------------------------------ static void setupRMSNorm(Fusion* fusion, DataType dtype) { - TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16); + TORCH_INTERNAL_ASSERT( + dtype == DataType::Float || dtype == DataType::Half || + dtype == DataType::BFloat16); FusionGuard fg(fusion); @@ -54,10 +56,11 @@ static void NvFuserScheduler_RMSNorm( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, DataType dtype) { - TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16); + TORCH_INTERNAL_ASSERT( + dtype == DataType::Float || dtype == DataType::Half || + dtype == DataType::BFloat16); - std::vector input_shape{ - 8, benchmark_state.range(0), 1024}; + std::vector input_shape{8, benchmark_state.range(0), 1024}; const float kEps = 1e-6; // inputs @@ -73,8 +76,7 @@ static void NvFuserScheduler_RMSNorm( benchmark_state.SetBytesProcessed( int64_t(benchmark_state.iterations()) * - (2 * input.numel() + weight.numel()) * - int64_t(dataTypeSize(dtype))); + (2 * input.numel() + weight.numel()) * int64_t(dataTypeSize(dtype))); } //------------------------------------------------------------------------------ diff --git a/benchmarks/cpp/nvfuser/rms_norm_backward.cpp b/benchmarks/cpp/nvfuser/rms_norm_backward.cpp index e6578417197c..3bd66b412b97 100644 --- a/benchmarks/cpp/nvfuser/rms_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/rms_norm_backward.cpp @@ -20,7 +20,9 @@ using namespace torch::jit::fuser::cuda; static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) { FusionGuard fg(fusion); - TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16); + TORCH_INTERNAL_ASSERT( + dtype == DataType::Float || dtype == DataType::Half || + dtype == DataType::BFloat16); const int kReductionAxis = 2; Double* eps_ptr = IrBuilder::create(1e-6); @@ -47,14 +49,12 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) { rstd = castOp(DataType::Float, rstd); } - auto rms_norm_results = rms_norm_backward( - grad_out, input, {1}, rstd, weight, {true, true, true}); + auto rms_norm_results = + rms_norm_backward(grad_out, input, {1}, rstd, weight, {true, true, true}); - if (dtype != DataType::Float ) { - rms_norm_results.grad_input = - castOp(dtype, rms_norm_results.grad_input); - rms_norm_results.grad_weight = - castOp(dtype, 
rms_norm_results.grad_weight); + if (dtype != DataType::Float) { + rms_norm_results.grad_input = castOp(dtype, rms_norm_results.grad_input); + rms_norm_results.grad_weight = castOp(dtype, rms_norm_results.grad_weight); } fusion->addOutput(rms_norm_results.grad_input); @@ -65,10 +65,11 @@ static void NvFuserScheduler_RMSNorm_BWD( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, DataType dtype) { - TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16); + TORCH_INTERNAL_ASSERT( + dtype == DataType::Float || dtype == DataType::Half || + dtype == DataType::BFloat16); - std::vector input_shape{ - 8, benchmark_state.range(0), 1024}; + std::vector input_shape{8, benchmark_state.range(0), 1024}; // inputs at::manual_seed(0); @@ -79,15 +80,13 @@ static void NvFuserScheduler_RMSNorm_BWD( at::Tensor weight = at::randn({input_shape[2]}, options); at::Tensor rstd = at::randn({input_shape[0], input_shape[1], 1}, options); - std::vector aten_inputs( - {grad_out, input, weight, rstd}); + std::vector aten_inputs({grad_out, input, weight, rstd}); runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs); benchmark_state.SetBytesProcessed( int64_t(benchmark_state.iterations()) * - (3 * input.numel() + weight.numel() + - rstd.numel()) * + (3 * input.numel() + weight.numel() + rstd.numel()) * int64_t(dataTypeSize(dtype))); } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 2c75dc60e3ed..3e8492409b44 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1052,8 +1052,7 @@ indexMapFromTV( // with zero isn't valid. That's only valid when there's a matching // IterDomain in the producer tensor that has the same parallel // type. - auto find_matching_parallel_domain = - [tv, within_mma_loops](IterDomain* id) -> bool { + auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool { const auto gpu_lower = GpuLower::current(); auto it = std::find_if( tv->domain()->domain().begin(), @@ -1063,8 +1062,7 @@ indexMapFromTV( // validateParallelize as well. return gpu_lower->caIndexMap().areMapped(id, tv_id) || (gpu_lower->caLoopMap().areMapped(id, tv_id) && - ir_utils::derivedFromRootCAAxes(tv, tv_id)) || - (id->getParallelType() == ParallelType::TIDx && within_mma_loops); + ir_utils::derivedFromRootCAAxes(tv, tv_id)); }); if (it == tv->domain()->domain().end()) { return false; @@ -1082,8 +1080,10 @@ indexMapFromTV( for (auto loop : loops) { Val* idx = nullptr; - const auto same_parallel_type = - as_consumer || find_matching_parallel_domain(loop->iter_domain()); + const auto same_parallel_type = as_consumer || + find_matching_parallel_domain(loop->iter_domain()) || + (within_mma_loops && + loop->iter_domain()->getParallelType() == ParallelType::TIDx); // See also LoopNestGenerator::pushAlloc. 
// NOLINTNEXTLINE(bugprone-branch-clone) if (!within_alloc) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 770d1d7949c5..051a8b108428 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -145,9 +145,12 @@ void checkDimSize( ExpressionEvaluator const_eval(tv->fusion()); for (auto axis_index : c10::irange(axis.size())) { TORCH_INTERNAL_ASSERT( - ((axis[axis_index]+tv->nDims()) >= 0) && - (axis[axis_index] < (int) tv->nDims()), - "CheckDimSize: axis position out of bound ", axis[axis_index]," ", tv->nDims()); + ((axis[axis_index] + tv->nDims()) >= 0) && + (axis[axis_index] < (int)tv->nDims()), + "CheckDimSize: axis position out of bound ", + axis[axis_index], + " ", + tv->nDims()); auto id = tv->axis(axis[axis_index]); auto maybe_extent = const_eval.evaluate(id->extent()); TORCH_CHECK( From 74f8c12d7c898266fb25f2ad8c4b908bea1a9c4f Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 16:17:07 -0700 Subject: [PATCH 08/57] move validation pass --- torch/csrc/jit/codegen/cuda/lower2device.cpp | 3 + .../codegen/cuda/lower_sync_information.cpp | 6 + .../jit/codegen/cuda/lower_validation.cpp | 302 +++++------------- .../csrc/jit/codegen/cuda/lower_validation.h | 4 + 4 files changed, 95 insertions(+), 220 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 8886e894171c..335d34180953 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -258,6 +258,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // Want to run this after parallel map is created validateVectorize(fusion_); + // Validate mma data format and compatibility if any on the fusion. + validateMma(fusion_); + // Extract TensorViews that are accessed in a vectorized way and track their // word size. fillVectorizeInfo(); diff --git a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp index a10df2bc9fba..629b3236fafc 100644 --- a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp @@ -178,6 +178,12 @@ void SyncMap::build(Fusion* fusion) { } for (auto parallel_type : kParallelTypeThreads) { + // TIDx is reserved for lane_id in the case of mma ops. + // It is swizzled and handled separately in validateMma. + if (parallel_type == ParallelType::TIDx && expr->isA()) { + continue; + } + auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type); auto p_id = producer_parallel_ids[parallel_type_i]; diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 13f9c2257932..9f1250c3b0de 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -454,226 +454,6 @@ void validateVectorize(Fusion* fusion) { namespace { -// // Validate parallelization of a single tensor -// void validateParallelizationOfTensor( -// TensorView* tv, -// bool is_mma_output = false) { -// // Each ParallelType can be used only once. 
-// ParallelTypeBitmap pt_map; -// for (size_t i = 0; i < tv->nDims(); ++i) { -// auto axis = tv->axis(i); -// auto ptype = axis->getParallelType(); -// if (!isParallelTypeThread(ptype)) { -// continue; -// } -// if (is_mma_output && ptype == ParallelType::TIDx) { -// TORCH_INTERNAL_ASSERT( -// axis->isMmaSwizzled(), -// "TIDx for mma output is reserved for warp mapping", -// axis, -// tv); -// // Check that TIDx is exact lane_id -// const auto& paralel_dim_map = GpuLower::current()->parallelDimensionMap(); - -// TORCH_INTERNAL_ASSERT( -// paralel_dim_map.isExact(ptype) && -// paralel_dim_map.get(ptype)->getInt().has_value() && -// paralel_dim_map.get(ptype)->getInt().value() == -// at::cuda::warp_size(), -// "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); - -// auto maybe_dim = paralel_dim_map.get(ptype); -// TORCH_INTERNAL_ASSERT(maybe_dim != nullptr); -// ExpressionEvaluator const_eval(tv->fusion()); -// auto maybe_dim_value = const_eval.evaluate(maybe_dim); -// TORCH_INTERNAL_ASSERT( -// maybe_dim_value.has_value() && -// maybe_dim_value.value() == at::cuda::warp_size(), -// "Mma: TIDx reserved for lane id"); -// } -// // It doesn't matter if this axis is a non-concretized broadcast -// // TODO: merging broadcast and non-broadcast -// if (axis->isBroadcast() && -// !GpuLower::current()->concretizedBroadcastDomains().isConcretized( -// axis)) { -// continue; -// } - -// TORCH_INTERNAL_ASSERT( -// !pt_map.get(ptype), -// "Multiple use of ", -// ptype, -// " in tensor t", -// tv->name(), -// ": ", -// tv); -// pt_map.set(ptype); -// } - -// // If this tensor is predicated by a paralel type, it should not be -// // used to parallelize any domain of this tensor - -// const auto thread_pred = -// GpuLower::current()->threadPredMap().getPredicateInfo(tv); - -// auto predicated_parallel_types = pt_map & thread_pred.limited_types; - -// TORCH_INTERNAL_ASSERT( -// predicated_parallel_types.none(), -// "Invalid parallelization of tensor t", -// tv->name(), -// ". The tensor is parallelized with ", -// predicated_parallel_types.toString(), -// ", but it's invalid to use the types as the tensor is also predicated with them.", -// ", thread pred: ", -// thread_pred.limited_types.toString()); -// } - -// } // namespace - -// void validateParallelize(Fusion* fusion) { -// FUSER_PERF_SCOPE("GpuLower::Lower::validateParallelize"); -// FusionGuard fg(fusion); - -// const auto& par_map = GpuLower::current()->caParallelMap(); -// const auto& loop_map = GpuLower::current()->caLoopMap(); -// const auto& pred_map = GpuLower::current()->threadPredMap(); - -// auto exprs = StmtSort::getExprs(fusion); - -// for (auto expr : exprs) { -// if (!ir_utils::isTvOp(expr)) { -// continue; -// } -// bool is_mma = expr->getExprType().value() == ExprType::MmaOp; -// // Validate parallelization of each consumer by itself -// for (auto consumer : ir_utils::filterByType(expr->outputs())) { -// validateParallelizationOfTensor(consumer, is_mma); -// } -// // Validate parallelization between a producer and a consumer -// for (auto producer : ir_utils::filterByType(expr->inputs())) { -// // Parallelization on input tensors have no effect. 
-// if (producer->isFusionInput()) { -// continue; -// } - -// if (is_mma) { -// // TODO: lift this check in a follow up -// TORCH_INTERNAL_ASSERT( -// std::all_of( -// producer->domain()->domain().begin() + -// producer->getComputeAtPosition(), -// producer->domain()->domain().end(), -// [](IterDomain* id) { -// return id->isMmaSwizzled() || -// (id->isBroadcast() && -// id->getParallelType() == ParallelType::Serial); -// }), -// "Temporary check: all id's on right of mma producer CA must be mma swizzled"); -// } -// const auto parallel_bcast_doms = -// pred_map.getParallelBroadcastDomains(producer); -// for (const auto i : c10::irange(producer->nDims())) { -// // If a producer axis is threaded, either with threadIdx or -// // blockIdx, there must be a mapped consumer axis with the -// // same ParallelType. An exception is when the producer is -// // allocated on shared memory and its parallelized with -// // threadIdx. In that case, there is no parallelization -// // constraint on the consumer as syncthreads will be inserted -// // when necessary. -// auto producer_axis = producer->axis(i); -// auto producer_ptype = -// par_map.getConcreteMappedID(producer_axis)->getParallelType(); -// if (!isParallelTypeThread(producer_ptype)) { -// continue; -// } -// if (is_mma && producer_ptype == ParallelType::TIDx) { -// TORCH_INTERNAL_ASSERT( -// producer_axis->isMmaSwizzled(), -// "mma input: use WarpMmaMapper to schedule warp input"); -// // Warp mapped ids will not map across mma op since they have -// // different swizzle formats. -// continue; -// } -// // When the producer axis is a broadcast, it is not really -// // parallelized unless thread-predicated -// if (producer_axis->isBroadcast() && -// !parallel_bcast_doms.get(producer_ptype)) { -// continue; -// } -// // No constraint on the consumer tensor when the producer -// // axis is parallelized with threadIdx and allocates on -// // shared memory -// if (isParallelTypeThreadDim(producer_ptype) && -// producer->getMemoryType() == MemoryType::Shared) { -// continue; -// } -// // There should be also nothing to validate when the producer -// // axis is reduction. -// if (producer_axis->isReduction()) { -// continue; -// } -// // There must be a consumer axis that uses the same indexing -// // with the same parallel type as the producer axis. The loop -// // map is used to to find such an axis. Broadcast forwarding -// // does not cause any inconsistent parallelization as indexing -// // takes care of the forwarding. -// for (auto consumer : -// ir_utils::filterByType(expr->outputs())) { -// auto it = std::find_if( -// consumer->domain()->domain().begin(), -// consumer->domain()->domain().end(), -// [&](IterDomain* consumer_axis) { -// return loop_map.areMapped(producer_axis, consumer_axis); -// }); -// TORCH_INTERNAL_ASSERT( -// it != consumer->domain()->domain().end(), -// "Inconsistent parallelization found between TV", -// producer->name(), -// " (", -// producer, -// ") and TV", -// consumer->name(), -// "(", -// consumer, -// "). ", -// "TV", -// consumer->name(), -// " does not have a matching axis for parallelized producer axis, ", -// producer_axis, -// ". 
CA Map: ", -// loop_map.toString()); -// auto consumer_axis = *it; -// auto consumer_ptype = -// par_map.getConcreteMappedID(consumer_axis)->getParallelType(); -// TORCH_INTERNAL_ASSERT( -// producer_ptype == consumer_ptype, -// "Inconsistent parallelization found between TV", -// producer->name(), -// " (", -// producer, -// ") and TV", -// consumer->name(), -// "(", -// consumer, -// "). " -// "Producer axis, ", -// producer_axis, -// " is parallelized with ", -// stringifyThread(producer_ptype), -// ", but the parallel type of its matching consumer axis, ", -// consumer_axis, -// " is ", -// stringifyThread(consumer_ptype), -// "."); -// } -// } -// } -// } -// } - -// namespace { - // Backward propagation of partial ranges from outputs to // inputs. Necessary to determine required ranges to compute. // @@ -855,6 +635,88 @@ void validatePartialSplit(Fusion* fusion) { } } +namespace { + +void validateMinimumArch(int major, int minor) { + auto prop = at::cuda::getCurrentDeviceProperties(); + TORCH_INTERNAL_ASSERT(prop->major >= major); + if (prop->major == major) { + TORCH_INTERNAL_ASSERT(prop->minor >= minor); + } +} + +void validateMmaTensors(MmaOp* mma) { + bool tidx_validated = false; + std::vector to_validate = { + mma->inA()->as(), + mma->inB()->as(), + mma->out()->as()}; + + for (auto tv : to_validate) { + for (auto id : tv->domain()->domain()) { + auto ptype = id->getParallelType(); + if (ptype == ParallelType::TIDx) { + TORCH_INTERNAL_ASSERT( + id->isMmaSwizzled(), + "TIDx for mma input/output must be set by WarpMmaSwizzler", + id, + tv); + if (!tidx_validated) { + // Check that TIDx is exact lane_id + const auto& paralel_dim_map = + GpuLower::current()->parallelDimensionMap(); + TORCH_INTERNAL_ASSERT( + paralel_dim_map.isExact(ptype) && + paralel_dim_map.get(ptype)->getInt().has_value() && + paralel_dim_map.get(ptype)->getInt().value() == + at::cuda::warp_size(), + "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); + tidx_validated = true; + } + } + } + } + + // Note: this check will be relaxed in a follow up. + auto validate_operand_ids = [](const TensorView* tv) { + TORCH_INTERNAL_ASSERT( + std::all_of( + tv->domain()->domain().begin() + tv->getComputeAtPosition(), + tv->domain()->domain().end(), + [](IterDomain* id) { + return id->isMmaSwizzled() || + (id->isBroadcast() && + id->getParallelType() == ParallelType::Serial); + }), + "All id's on the right of CA pos needs to be mma-swizzled by WarpMmaSwizzler\n", + tv); + }; + + validate_operand_ids(mma->inA()->as()); + validate_operand_ids(mma->inB()->as()); +} + +} // namespace + +void validateMma(Fusion* fusion) { + auto exprs = StmtSort::getExprs(fusion); + + for (auto expr : exprs) { + if (auto mma = dynamic_cast(expr)) { + validateMmaTensors(mma); + + switch (mma->options().macro) { + case MmaOptions::MacroType::Volta_16_16_4: + validateMinimumArch(7, 0); + break; + default: + TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro"); + break; + } + } + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index c547981f6561..0f23a551a4aa 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -21,6 +21,10 @@ void validateVectorize(Fusion* fusion); //! calculated that are necessary for output values. void validatePartialSplit(Fusion* fusion); +//! 
Validate data format and GPU arch compatibility of scheduled +//! mma operators on the fusion. +void validateMma(Fusion* fusion); + } // namespace cuda } // namespace fuser } // namespace jit From db341812c5263145cfadbd0f259cb745ac30bf7c Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 22:25:07 -0700 Subject: [PATCH 09/57] comment and cleanup --- torch/csrc/jit/codegen/cuda/lower_predicate.cpp | 4 ---- torch/csrc/jit/codegen/cuda/lower_validation.cpp | 7 +++++++ torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h | 13 +++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 631caae2b25c..166f38d6cf56 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -157,17 +157,13 @@ class PredicateAnalyzer : public OptOutDispatch { // of the parallelized axis is the actual size of the axis, not // the number of threads. Since the number of threads can be // larger than the axis size, it's not safe to skip predication - bool has_global_access = producer->getMemoryType() == MemoryType::Global || - consumer->getMemoryType() == MemoryType::Global; - bool needs_sharedmem_addr_pred = false; // Check that parallel dimension will not generate out of bound index if (!(producer->getMemoryType() == MemoryType::Local && consumer->getMemoryType() == MemoryType::Local)) { return true; } - bool needs_index_predicate = false; auto pairwise_map = PairwiseRootDomainMap(producer, consumer); auto c2p = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 9f1250c3b0de..58f61844ed2e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -637,6 +637,8 @@ void validatePartialSplit(Fusion* fusion) { namespace { +//! Utility to make sure targeted gpu capability is +//! higher than provided major.minor. void validateMinimumArch(int major, int minor) { auto prop = at::cuda::getCurrentDeviceProperties(); TORCH_INTERNAL_ASSERT(prop->major >= major); @@ -645,6 +647,9 @@ void validateMinimumArch(int major, int minor) { } } +//! Validates that the operand and result tensors +//! of mma ops are swizzled and also validates +//! specialization of tidx as lane id. void validateMmaTensors(MmaOp* mma) { bool tidx_validated = false; std::vector to_validate = { @@ -698,6 +703,8 @@ void validateMmaTensors(MmaOp* mma) { } // namespace +//! Validate data format and GPU arch compatibility of scheduled +//! mma operators on the fusion. void validateMma(Fusion* fusion) { auto exprs = StmtSort::getExprs(fusion); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h index f98c63961ccd..783da22b46c8 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h @@ -33,12 +33,10 @@ namespace mma_util { //! as follows: //! //! Step 1. Before scheduling, the mma op needs to be configured with a macro -//! type, -//! either manually or inferred (eg. Volta_16_16_4). +//! type, either manually or inferred (eg. Volta_16_16_4). //! //! Step 2. Scheduler can tile the outer dimensions based on any heuristics, -//! i.e. -//! the CTA tiling, warp tiling, splitK etc. +//! i.e. the CTA tiling, warp tiling, splitK etc. //! //! Step 3. 
The scheduler will need to split the innermost part of the 3 //! involved @@ -98,10 +96,9 @@ namespace mma_util { //! //! 2. Swizzled data input: //! It is also possible that the data input has other swizzles before -//! entering the fusion already and some -//! might be natively compatible with mma format. This is a very broad -//! category of use cases and we'd have to consider enabling any use like -//! this case-by-case. +//! entering the fusion already and some might be natively compatible +//! with mma format. This is a very broad category of use cases +//! and we'd have to consider enabling any use like this case-by-case. class TORCH_CUDA_CU_API WarpMmaSwizzler { public: //! Applies the output mma swizzling to the given tv, should be used From 4dec8276f9c0ad6885bcf3804e182cef5b03cf13 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 23:11:21 -0700 Subject: [PATCH 10/57] lint --- test/cpp/jit/test_gpu_tensorcore.cpp | 2 +- torch/csrc/jit/codegen/cuda/mma_type.cpp | 2 +- torch/csrc/jit/codegen/cuda/mma_type.h | 2 +- torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp | 2 +- torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index e05d06c8da90..bf06155c0166 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -896,4 +896,4 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { } // namespace jit } // namespace torch -#endif \ No newline at end of file +#endif diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 58280c47fc36..dc77949fba3a 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -134,4 +134,4 @@ std::string toString(MmaOptions::MacroType mt) { } // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 407ca3c97725..63e50d51e841 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -129,4 +129,4 @@ std::string toString(MmaOptions::MacroType mt); } // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 051a8b108428..875a9ea5ab15 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -429,4 +429,4 @@ bool isMmaInitLoop(const kir::ForLoop* loop) { } // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h index 783da22b46c8..2ee1b4473277 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h @@ -138,4 +138,4 @@ bool isMmaInitLoop(const kir::ForLoop* loop); } // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch From 5ecd102c40a7ec03217fe883f6f9d5e696e30cce Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 16 Mar 2022 23:37:53 -0700 Subject: [PATCH 11/57] comment and cleanup --- 
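Illustrative end-to-end use of the renamed MatMulTileOptions with MmaBuilder
and the scheduling utilities documented in this patch (a sketch only,
assembled from the test updates in this series; the input-layout
configuration on MmaBuilder is omitted here, and tensor names such as
tv2c/tv0cr/tv1cr are taken from the tests):

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  MmaBuilder mma_builder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile);

  // Hierarchical tiling of the mma output, then warp-level swizzling of
  // the operands and the accumulator.
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::NotOperand).build());
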
test/cpp/jit/test_gpu_tensorcore.cpp | 12 +++++------ torch/csrc/jit/codegen/cuda/mma_type.cpp | 4 +++- torch/csrc/jit/codegen/cuda/mma_type.h | 10 +++++----- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 6 +++--- torch/csrc/jit/codegen/cuda/scheduler/utils.h | 20 +++++++++++++++---- 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index bf06155c0166..4dd74c8566e0 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -137,7 +137,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { // TODO: should be able to completely remove it // in a follow up. - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(16, 16, 4); gemm_tile.warp_tile = GemmTile(16, 16, 4); gemm_tile.instruction_tile = GemmTile(16, 16, 4); @@ -230,7 +230,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { // TODO: should be able to completely remove it // in a follow up. - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(16, 16, 4); gemm_tile.warp_tile = GemmTile(16, 16, 4); gemm_tile.instruction_tile = GemmTile(16, 16, 4); @@ -291,7 +291,7 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { fusion.addOutput(tv2); - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(16, 16, 4); gemm_tile.warp_tile = GemmTile(16, 16, 4); gemm_tile.instruction_tile = GemmTile(16, 16, 4); @@ -363,7 +363,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { fusion.addOutput(tv2); - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 4); @@ -613,7 +613,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { fusion.addOutput(tv2); - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 4); @@ -762,7 +762,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { fusion.addOutput(tv2); - GemmTileOptions gemm_tile; + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32); gemm_tile.instruction_tile = GemmTile(16, 16, 4); diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index dc77949fba3a..3751cdea6bcf 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -5,7 +5,9 @@ namespace jit { namespace fuser { namespace cuda { -MmaBuilder::MmaBuilder(MmaOptions::MacroType macro, GemmTileOptions gemm_tile) { +MmaBuilder::MmaBuilder( + MmaOptions::MacroType macro, + MatMulTileOptions gemm_tile) { option_.macro = macro; // Calculate accumulator stride, will be removed once transpose swizzle ready int outer_stride = gemm_tile.warp_tile.n / gemm_tile.instruction_tile.n; diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 63e50d51e841..5f42d41ded65 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -22,13 +22,13 @@ struct GemmTile { }; //! 
Utility data structure for recording gemm tiles -struct TORCH_CUDA_CU_API GemmTileOptions { +struct TORCH_CUDA_CU_API MatMulTileOptions { GemmTile cta_tile = GemmTile(128, 128, 32); GemmTile warp_tile = GemmTile(64, 64, 32); GemmTile instruction_tile = GemmTile(16, 8, 16); - GemmTileOptions() = default; - GemmTileOptions( + MatMulTileOptions() = default; + MatMulTileOptions( GemmTile cta_tile_, GemmTile warp_tile_, GemmTile instruction_tile_) @@ -36,7 +36,7 @@ struct TORCH_CUDA_CU_API GemmTileOptions { warp_tile(warp_tile_), instruction_tile(instruction_tile_) {} - bool operator==(const GemmTileOptions& other) { + bool operator==(const MatMulTileOptions& other) { return cta_tile == other.cta_tile && warp_tile == other.warp_tile && instruction_tile == other.instruction_tile; } @@ -98,7 +98,7 @@ struct MmaOptions { //! User interface generating mma options for mma op class TORCH_CUDA_CU_API MmaBuilder { public: - MmaBuilder(MmaOptions::MacroType macro, GemmTileOptions gemm_tile); + MmaBuilder(MmaOptions::MacroType macro, MatMulTileOptions gemm_tile); MmaBuilder& layout(MmaOptions::MmaInputLayout layout); MmaBuilder& operand(MmaOptions::Operand a_or_b); MmaOptions build() const; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 7c57047ca9fd..c5fa015930da 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -1385,7 +1385,7 @@ std::vector getBroadcastMultiples(TensorView* reference_tv) { namespace matmul_utils { -void scheduleWarpTileWithReduction(TensorView* tv, GemmTileOptions tile) { +void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { // Assumes // [M, N, K] auto cta_tile = tile.cta_tile; @@ -1421,7 +1421,7 @@ void scheduleWarpTileWithReduction(TensorView* tv, GemmTileOptions tile) { // [Mwo Nwo Ko Mw Nw Mi Ni Ki] } -void scheduleWarpTileWithNoReduction(TensorView* tv, GemmTileOptions tile) { +void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { // Assumes // [M, N, K] auto cta_tile = tile.cta_tile; @@ -1454,7 +1454,7 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, GemmTileOptions tile) { //! Split the innermost dim to a vectorized load void scheduleContiguousVectorLoad( TensorView* tv, - GemmTileOptions tile, + MatMulTileOptions tile, int vector_word) { auto warp_dims = tile.cta_tile / tile.warp_tile; int num_of_thread = warp_dims.m * warp_dims.n * warp_dims.k * 32; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 830c682cf9ce..bce6c79f5513 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -251,19 +251,31 @@ struct BroadcastMultiple { std::vector getBroadcastMultiples(TensorView* reference_tv); namespace matmul_utils { - +//! Utilities in this namespace facilitates scheduling matmul kernels with +//! hierarchichal tiling specified in MatMulTileOptions. + +//! Schedule utility for matmul prolog: +//! Use all the threads on a CTA tile to load matmul operands +//! into shared memory with the given vectorization word. +//! TODO: +//! will need to add bank conflict removal swizzle in a follow up. TORCH_CUDA_CU_API void scheduleContiguousVectorLoad( TensorView* tv, - GemmTileOptions tile, + MatMulTileOptions tile, int vector_word); +//! Schedule utility for mma output in matmul main loop: +//! Realize the hierarchical tiling based on the given tiling options. 
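// Illustrative sketch, not from the patch: how the three tile levels in
// MatMulTileOptions are meant to nest. Sizes are the struct defaults; axis
// names follow the comments used in the unit tests below.
//
//   MatMulTileOptions tile;
//   tile.cta_tile         = GemmTile(128, 128, 32); // tile owned by one CTA
//   tile.warp_tile        = GemmTile(64, 64, 32);   // tile owned by one warp
//   tile.instruction_tile = GemmTile(16, 8, 16);    // one mma instruction
//
//   // scheduleWarpTileWithReduction(tv, tile) then splits an [M, N, K]
//   // domain roughly into [Mwo, Nwo, Kwo, Mw, Nw, Mi, Ni, Ki], where
//   //   Mwo = cta_tile.m / warp_tile.m = 2   (later bound to TIDz)
//   //   Nwo = cta_tile.n / warp_tile.n = 2   (later bound to TIDy)
//   //   Mw  = warp_tile.m / instruction_tile.m, etc., and the innermost
//   //   (Mi, Ni, Ki) axes are swizzled onto the mma instruction.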
TORCH_CUDA_CU_API void scheduleWarpTileWithReduction( TensorView* tv, - GemmTileOptions tile); + MatMulTileOptions tile); +//! Schedule utility for mma output in matmul main loop: +//! Realize the hierarchical tiling based on the given tiling options +//! on consumers of mma ops in epilog. TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction( TensorView* tv, - GemmTileOptions tile); + MatMulTileOptions tile); } // namespace matmul_utils From c97c60585dd975f5a112a9e905067d421eb08472 Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 17 Mar 2022 14:58:27 -0700 Subject: [PATCH 12/57] comment and format --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 5 +++++ torch/csrc/jit/codegen/cuda/lower2device.cpp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 0d08a185bd8c..e360c2c1b673 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -871,6 +871,11 @@ indexMapFromTV( Val* idx = nullptr; const auto same_parallel_type = as_consumer || find_matching_parallel_domain(loop->iter_domain()) || + // Note && TODO: + // mma swizzled lane_id does not map naturally from producer + // to consumer but they should still be detected as same + // parallel type. In a follow up may want to extent + // find_matching_parallel_domain to cover this case. (within_mma_loops && loop->iter_domain()->getParallelType() == ParallelType::TIDx); // See also LoopNestGenerator::pushAlloc. diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 639ffc096ec8..de54e7b50434 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -257,7 +257,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // Validate mma data format and compatibility if any on the fusion. validateMma(fusion_); - + // Compute thread predicates. 
Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); From 264bc7798a273ecd055a8d230084d481e0eac1a4 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 21 Mar 2022 15:30:11 -0700 Subject: [PATCH 13/57] initial turing and ampere mma support --- caffe2/CMakeLists.txt | 1 + test/cpp/jit/test_gpu_tensorcore.cpp | 1451 +++++++++++++++++ third_party/XNNPACK | 2 +- third_party/fbgemm | 2 +- third_party/flatbuffers | 2 +- third_party/ideep | 2 +- tools/build_variables.bzl | 1 + torch/csrc/jit/codegen/cuda/codegen.cpp | 61 +- .../csrc/jit/codegen/cuda/executor_utils.cpp | 2 + .../jit/codegen/cuda/ir_interface_nodes.h | 22 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 5 +- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 6 +- torch/csrc/jit/codegen/cuda/kernel_ir.h | 22 +- .../jit/codegen/cuda/lower_double_buffer.cpp | 17 +- .../jit/codegen/cuda/lower_double_buffer.h | 9 + .../jit/codegen/cuda/lower_insert_syncs.cpp | 29 +- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 118 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 58 +- .../jit/codegen/cuda/lower_validation.cpp | 93 +- torch/csrc/jit/codegen/cuda/mma_type.cpp | 41 + torch/csrc/jit/codegen/cuda/mma_type.h | 3 +- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 105 ++ .../jit/codegen/cuda/runtime/tensorcore.cu | 78 +- .../jit/codegen/cuda/scheduler/mma_utils.cpp | 226 +++ .../jit/codegen/cuda/scheduler/mma_utils.h | 6 + torch/csrc/jit/codegen/cuda/tensor_view.cpp | 8 +- torch/csrc/jit/codegen/cuda/type.cpp | 6 + torch/csrc/jit/codegen/cuda/type.h | 11 +- 28 files changed, 2350 insertions(+), 37 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/runtime/memory.cu diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 2a443dfaef69..ef593a3fd608 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -944,6 +944,7 @@ if(USE_CUDA OR USE_ROCM) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/warp.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensorcore.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/memory.cu ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh ) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index 4dd74c8566e0..50b4269af882 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -891,6 +891,1457 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } +TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [N,K] + auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
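// Illustrative reference semantics for the TN layout built above (a scalar
// sketch, not code from this patch; A, B, C stand for tv0, tv1 and the
// result):
//
//   for (int m = 0; m < 16; ++m)
//     for (int n = 0; n < 8; ++n) {
//       float acc = 0.f;
//       for (int k = 0; k < 16; ++k)
//         acc += __half2float(A[m][k]) * __half2float(B[n][k]);
//       C[m][n] = acc; // fusedMultiplySum(tv0b, tv1b, {2}) reduces axis 2
//     }
//
//   i.e. C = A * B^T, matching the at::matmul(t0, t1.t()) reference check.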
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cache_before(); + + // [M,N,K] -> [N,M,K] + tv0cr->reorder({{-2, -3}, {-3, -2}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({8, 16}, options); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cache_before(); + + // [M,K,N] -> [N,M,K] + tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [M,K,N] -> [M,N,K] + tv1cr->reorder({{-2, -1}, {-1, -2}}); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [M,K,N] -> [M,N,K] + tv2c->reorder({{-2, -1}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 
= at::randn({16, 16}, options); + auto t1 = at::randn({16, 8}, options); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [K,M] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0b->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cache_before(); + + // [K,M,N] -> [N,M,K] + tv0cr->reorder({{-3, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [K,M,N] -> [M,N,K] + tv1cr->reorder({ + {-3, -1}, + {-2, -3}, + {-1, -2}, + }); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [K,M,N] -> [M,N,K] + tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({16, 8}, options); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTuringGemmTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 511, N = 257, K = 88; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
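// For reference, the cache_before/cache_after chain built a few lines below
// (and in the unit tests above) stages operands through the usual mma
// pipeline; a sketch of intent, not code from the patch:
//
//   tv0   (global)   --Set-------> tv0r   registers, vectorized gmem read
//   tv0r             --Set-------> tv0cw  shared memory staging buffer
//   tv0cw (shared)   --LdMatrix--> tv0cr  registers in mma operand layout
//   tv0cr, tv1cr     --MmaOp-----> tv2c   accumulator registers
//   tv2c             --Set-------> tv2    global memory output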
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0r->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +TEST_F(NVFuserTest, FusionTuringGemmTT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 512, N = 256, K = 128; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
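// Rough launch arithmetic implied by the tiling in these gemm tests
// (derived from the tile sizes, not stated explicitly in the patch):
//
//   warps per CTA = (cta_tile.m / warp_tile.m) * (cta_tile.n / warp_tile.n)
//                 = (128 / 64) * (128 / 64) = 2 * 2
//   block         = TIDx 32 (mma lanes) * TIDy 2 * TIDz 2 = 128 threads
//   grid          = BIDx ceilDiv(M, 128) * BIDy ceilDiv(N, 128),
//                   e.g. 4 * 2 CTAs for the M = 512, N = 256 case in this test.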
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0r->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] -> [No,Ko,K,N] + tv1cw->reorder({{-2, -1}, {-1, -2}}); + tv1r->reorder({{-2, -1}, {-1, -2}}); + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +TEST_F(NVFuserTest, FusionTuringGemmNT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 512, N = 256, K = 128; + + // [K,M] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + tv2->configureMma(mma_builder.build()); + + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + auto tv0cw = tv0r->cache_after(); + auto tv0cr = tv0cw->cache_after( + mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cache_after(); + auto tv1cr = tv1cw->cache_after( + mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, 
gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] -> [..., K,M] + tv0cw->reorder({{-2, -1}, {-1, -2}}); + tv0r->reorder({{-2, -1}, {-1, -2}}); + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] -> [No,Ko,K,N] + tv1cw->reorder({{-2, -1}, {-1, -2}}); + tv1r->reorder({{-2, -1}, {-1, -2}}); + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({K, M}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +TEST_F(NVFuserTest, FusionGemmGemmTuring_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 512, N = 256, K1 = 128, K2 = 128; + + // Fusion definition (Both gemms are TN) + // [M,K1] + auto tv0 = makeConcreteTensor({M, K1}, DataType::Half); + // [K2,K1] + auto tv1 = makeConcreteTensor({K2, K1}, DataType::Half); + // [N,K2] + auto tv2 = makeConcreteTensor({N, K2}, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + auto tv2b = broadcast(tv2, {true, false, false}); + + // [M,K2,R] + auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); + + auto tv3h = castOp(DataType::Half, tv3); + auto tv3b = broadcast(tv3h, {false, true, false}); + + auto tv4 = 
fusedMultiplySum(tv3b, tv2b, {2}); + + fusion.addOutput(tv4); + + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) + + MatMulTileOptions gemm_tile1, gemm_tile2; + + // cta tile: + // To save register, n of cta tile 1 + // matches k of cta tile2 + gemm_tile1.cta_tile = GemmTile(128, 64, 32); + gemm_tile2.cta_tile = GemmTile(128, 32, 64); + + // Distribute to 2x2 warps + gemm_tile1.warp_tile = GemmTile(64, 32, 32); + gemm_tile2.warp_tile = GemmTile(64, 16, 64); + + // Using turing mma macro + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); + gemm_tile2.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder1 = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile1) + .layout(MmaOptions::MmaInputLayout::TN); + + auto mma_builder2 = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile2) + .layout(MmaOptions::MmaInputLayout::TN); + + tv3->configureMma(mma_builder1.build()); + tv4->configureMma(mma_builder2.build()); + + // Global read for gemm 1 + auto tv0r = tv0->cache_after(); + auto tv1r = tv1->cache_after(); + + // Global read for gemm 2 + auto tv2r = tv2->cache_after(); + + // Gemm 1 main loop read + auto tv0cw = tv0r->cache_after(); + auto tv0cr = tv0cw->cache_after(UnaryOpType::LD_MATRIX); + auto tv1cw = tv1r->cache_after(); + auto tv1cr = tv1cw->cache_after(UnaryOpType::LD_MATRIX); + + // Gemm 1 accumulator reg + auto tv3c = tv3->cache_before(); + + // Gemm 2 main loop read + auto tv3cw = tv3h->cache_after(); + auto tv3cr = tv3cw->cache_after(UnaryOpType::LD_MATRIX); + + auto tv2cw = tv2r->cache_after(); + auto tv2cr = tv2cw->cache_after(UnaryOpType::LD_MATRIX); + + // Gemm 2 accumulator reg + auto tv4c = tv4->cache_before(); + + // General idea is inlining gemm1's main loop inside gemm2's + + // Schedule gemm 2: + // ------------------------------------------------------------------ + tv4->split(-2, gemm_tile2.cta_tile.m); + tv4->split(-1, gemm_tile2.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv4->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv2->computeAt(tv4, 2); + tv3->computeAt(tv4, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv4c->split(-1, gemm_tile2.cta_tile.k); + tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 + tv2r->computeAt(tv4c, 3); + + // Make warp tile + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( + tv4c, gemm_tile2); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv4, gemm_tile2); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv3cr->computeAt(tv4c, -4); + tv2cr->computeAt(tv4c, -4); + + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + // [No,Ko,N,K] + tv2cw->merge(-2); + tv2r->merge(-2); + + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2cw, gemm_tile2, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2r, gemm_tile2, 8); + tv2cw->setMemoryType(MemoryType::Shared); + + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + + // Schedule gemm 2 mma input + // --------------------------------------------------------------------------- + tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv3b->reorder({{-2, -3}, {-3, -2}}); + tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + + tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv4c->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + tv4->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + + // Schedule gemm 1: + // ------------------------------------------------------------------ + + // CTA tile: + tv0->computeAt(tv3, 2); + tv1->computeAt(tv3, 2); + + // Schedule K dim for gemm 1: + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv3c->split(-1, gemm_tile1.cta_tile.k); + tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv3c, 3); + tv1r->computeAt(tv3c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( + tv3c, gemm_tile1); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv3cw, gemm_tile1); + + tv0cr->computeAt(tv3c, -4); + tv1cr->computeAt(tv3c, -4); + + tv3->computeAt(tv3cw, -3); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile1, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile1, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile1, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile1, 8); + tv1cw->setMemoryType(MemoryType::Shared); + + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv3c->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + tv3cw->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + tv3h->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + tv3->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + tv3cw->setMemoryType(MemoryType::Shared); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(3)->parallelize(ParallelType::TIDz); + tv3c->axis(4)->parallelize(ParallelType::TIDy); + + tv3->computeAt(tv3cw, -2); + tv3cw->axis(2)->parallelize(ParallelType::TIDz); + tv3cw->axis(3)->parallelize(ParallelType::TIDy); + + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(3)->parallelize(ParallelType::TIDz); + tv4c->axis(4)->parallelize(ParallelType::TIDy); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::BIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K1}, options); + auto t1 = at::randn({K2, K1}, options); + auto t2 = at::randn({N, K2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1, t2}); + + auto tref = t0.to(at::kFloat) + .matmul(t1.t().to(at::kFloat)) + .matmul(t2.t().to(at::kFloat)); + + // relaxed check for now, err accumulation is significant. 
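// (The tolerance is loose because tv3 is cast back to half between the two
//  gemms, so the rounding of a K1 = 128 reduction feeds a second K2 = 128
//  reduction; the single-gemm tests above can use 1e-4 instead.)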
+ TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); +} + +TEST_F(NVFuserTest, FusionGemmSoftmaxGemmTuring_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // Omitting outer dimensions and pointwise ops + + const int seql_q = 32; + const int seql_k = 128; + const int hidden_size = 1024; + const int num_heads = 16; + const int head_dim = hidden_size / num_heads; + + // Gemm 1: + // (80, 80, 64) + const int M1 = seql_q, N1 = seql_k, K1 = head_dim; + // (80, 64, 80) + const int M2 = seql_q, N2 = head_dim, K2 = seql_k; + + // Fusion definition (Both gemms are TN) + // [M,K1] + auto inp = makeConcreteTensor({M1, K1}, DataType::Half); + // Query matrix + auto qk = makeConcreteTensor({N1, K1}, DataType::Half); + // Second linear matrix + auto acc = makeConcreteTensor({N2, K2}, DataType::Half); + + fusion.addInput(inp); + fusion.addInput(qk); + fusion.addInput(acc); + + // [M,N,K] + auto tv0b = broadcast(inp, {false, true, false}); + auto tv1b = broadcast(qk, {true, false, false}); + auto tv2b = broadcast(acc, {true, false, false}); + + // [M,K2,R] + auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); + + // Inline define softmax for now for scheduling + auto x = tv3; + const int kReductionAxis = 1; + const int kNumberOfDims = 2; + std::vector broadcast_mask(kNumberOfDims, false); + broadcast_mask[kReductionAxis] = true; + + auto max_val = max(x, {kReductionAxis}); + auto bcast_max = broadcast(max_val, broadcast_mask); + auto x_max_sub = sub(x, bcast_max); + auto exp_val = exp(x_max_sub); + auto sum_exp = sum(exp_val, {kReductionAxis}); + auto bcast_sum = broadcast(sum_exp, broadcast_mask); + auto recip = reciprocal(bcast_sum); + auto tv3sfm = mul(exp_val, recip); + + auto tv3h = castOp(DataType::Half, tv3sfm); + auto tv3b = broadcast(tv3h, {false, true, false}); + auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); + + fusion.addOutput(tv4); + + // Fusion: + // Gemm(M,K2,K1) x Gemm(M,N,K2) + MatMulTileOptions gemm_tile; + + // TODO: use very small tiles for now since + // alias pass is not re-using smem. Fix later. 
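// The inline softmax defined above is the standard max-subtracted
// formulation along the seql_k axis; per row i of tv3 the ops compose to
// (a worked sketch, not code from the patch):
//
//   m_i  = max_j x[i][j]           // max_val, broadcast as bcast_max
//   e_ij = exp(x[i][j] - m_i)      // x_max_sub, exp_val
//   s_i  = sum_j e_ij              // sum_exp, broadcast as bcast_sum
//   y_ij = e_ij * (1 / s_i)        // recip, tv3sfm, then cast to half as tv3h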
+ gemm_tile.cta_tile = GemmTile(32, 128, 32); + + // Distribute to 2x2 warps + gemm_tile.warp_tile = GemmTile(16, 64, 32); + + // Using turing mma macro + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder1 = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + auto mma_builder2 = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv3->configureMma(mma_builder1.build()); + tv4->configureMma(mma_builder2.build()); + + // Global read for gemm 1 + auto tv0r = inp->cache_after(); + auto tv1r = qk->cache_after(); + + // Global read for gemm 2 + auto tv2r = acc->cache_after(); + + // Gemm 1 main loop read + auto tv0cw = tv0r->cache_after(); + auto tv0cr = tv0cw->cache_after(UnaryOpType::LD_MATRIX); + auto tv1cw = tv1r->cache_after(); + auto tv1cr = tv1cw->cache_after(UnaryOpType::LD_MATRIX); + + // Gemm 1 accumulator reg + auto tv3c = tv3->cache_before(); + + // Softmax conversion: + auto tv3ccr = tv3->cache_after(); + + // tv3ccr -> tv3h : softmax + + // Gemm 2 main loop read + // auto tv3cw = tv3h->cache_after(); + auto tv3cr = tv3h->cache_after(UnaryOpType::LD_MATRIX); + + auto tv2cw = tv2r->cache_after(); + auto tv2cr = tv2cw->cache_after(UnaryOpType::LD_MATRIX); + + // Gemm 2 accumulator reg + auto tv4c = tv4->cache_before(); + + // Schedule gemm 2: + // ------------------------------------------------------------------ + tv4->split(-2, gemm_tile.cta_tile.m); + tv4->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv4->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + acc->computeAt(tv4, 2); + tv3->computeAt(tv4, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv4c->split(-1, gemm_tile.cta_tile.k); + tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv3->computeAt(tv4c, 2); + tv2r->computeAt(tv4c, 3); + + // Make warp tile + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv4, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv3cr->computeAt(tv4c, -4); + tv2cr->computeAt(tv4c, -4); + + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + // [No,Ko,N,K] + tv2cw->merge(-2); + tv2r->merge(-2); + + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv2r, gemm_tile, 8); + tv2cw->setMemoryType(MemoryType::Shared); + + // Schedule tv2 gmem read and smem write: + // ---------------------------------------------------------------- + + // Schedule gemm 2 mma input + // --------------------------------------------------------------------------- + tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv3b->reorder({{-2, -3}, {-3, -2}}); + tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); + + tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv4c->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + tv4->applyMmaSwizzle( + mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + + // Schedule gemm 1: + // ------------------------------------------------------------------ + + // CTA tile: + // [Mo, Mi128, N80] + + tv3->split(-1, gemm_tile.cta_tile.n); + // [Mo, Mi128, No, Ni128] + + tv3->reorder({{1, 2}, {2, 1}}); + + // [Mo, No, Mi128, Ni128] + inp->computeAt(tv3, 2); + qk->computeAt(tv3, 2); + + // Schedule K dim for gemm 1: + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv3c->split(-1, gemm_tile.cta_tile.k); + tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv3c, 3); + tv1r->computeAt(tv3c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv3, gemm_tile); + + tv0cr->computeAt(tv3c, -4); + tv1cr->computeAt(tv3c, -4); + + // tv3->computeAt(tv3cw,-3); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); + + // // Schedule mma output + // // + // --------------------------------------------------------------------------- + tv3c->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + tv3->applyMmaSwizzle( + mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + + // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, + // mma_builder1.build()); + + // Put tv3 result in smem + tv3->setMemoryType(MemoryType::Shared); + + // schedule a reg persistent softmax: from tv3 + // [Mo, M128, RN] + max_val->split(-1, 128); + // [Mo, M128, RN1, RN128] + max_val->split(-1, 4); + // Map to warp (2x2) + max_val->split(-4, 4); + max_val->split(-4, 2); + + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto max_rf = max_val->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] + + // [Mo, M128, RN] + sum_exp->split(-1, 128); + // [Mo, M128, RN1, RN128] + sum_exp->split(-1, 4); + // Map to warp (2x2) + sum_exp->split(-4, 4); + sum_exp->split(-4, 2); + + // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] + auto sum_exp_rf = sum_exp->rFactor({-1}); + // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] + + exp_val->computeAt(sum_exp_rf, 4); + exp_val->split(-1, 128); + exp_val->split(-1, 4); + bcast_max->computeAt(exp_val, -2); + + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] + + // Read from smem + tv3ccr->computeAt(max_rf, 4); + // [Mo, Mo32, My2, Mx2, N80] + tv3ccr->split(-1, 128); + tv3ccr->split(-1, 4); + // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] + + // Write to second gemm + tv3h->split(-1, 128); + tv3h->split(-1, 4); + // Map to warp (2x2) + tv3h->split(-4, 4); + tv3h->split(-4, 2); + + bcast_sum->computeAt(tv3h, -2); + + tv3h->setMemoryType(MemoryType::Shared); + + // Parallelize + tv4->axis(0)->parallelize(ParallelType::BIDx); + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 1 + tv3c->axis(3)->parallelize(ParallelType::TIDz); + tv3c->axis(4)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDz); + tv3->axis(3)->parallelize(ParallelType::TIDy); + + auto parallelize_non_reduced_val = [](TensorView* tv) { + tv->axis(-2)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; + + auto parallelize_reduced_val = [](TensorView* tv) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + tv->axis(2)->parallelize(ParallelType::TIDz); + tv->axis(3)->parallelize(ParallelType::TIDy); + }; + + parallelize_non_reduced_val(tv3h); + parallelize_non_reduced_val(max_rf); + parallelize_non_reduced_val(bcast_max); + parallelize_non_reduced_val(exp_val); + parallelize_non_reduced_val(sum_exp_rf); + parallelize_non_reduced_val(bcast_sum); + parallelize_non_reduced_val(recip); + + parallelize_reduced_val(max_val); + parallelize_reduced_val(sum_exp); + + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + // Gemm 2 + tv4->axis(2)->parallelize(ParallelType::TIDz); + tv4->axis(3)->parallelize(ParallelType::TIDy); + tv4c->axis(3)->parallelize(ParallelType::TIDz); + tv4c->axis(4)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M1, K1}, options); + auto t1 = at::randn({N1, K1}, options); + auto t2 = 
at::randn({N2, K2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1, t2}); + + auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + auto sg1 = at::_softmax(g1, -1, false); + auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); +} + +TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 255, N = 511, K = 88; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + tv2->configureMma(mma_builder.build()); + + auto tv0cw = tv0->cache_after(UnaryOpType::CP_ASYNC); + auto tv0cr = tv0cw->cache_after(UnaryOpType::LD_MATRIX); + auto tv1cw = tv1->cache_after(UnaryOpType::CP_ASYNC); + auto tv1cr = tv1cw->cache_after(UnaryOpType::LD_MATRIX); + auto tv2c = tv2->cache_before(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0cw->computeAt(tv2c, 3); + tv1cw->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit diff --git a/third_party/XNNPACK b/third_party/XNNPACK index ae108ef49aa5..79cd5f9e18ad 160000 --- a/third_party/XNNPACK +++ b/third_party/XNNPACK @@ -1 +1 @@ -Subproject commit ae108ef49aa5623b896fc93d4298c49d1750d9ba +Subproject commit 79cd5f9e18ad0925ac9a050b00ea5a36230072db diff --git a/third_party/fbgemm b/third_party/fbgemm index 1ddff63cd3a9..7588d9d80482 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 1ddff63cd3a99bdd8f52e8147dbfe723522d2f48 +Subproject commit 7588d9d804826b428fc0e4fd418e9cc3f7a72e52 diff --git a/third_party/flatbuffers b/third_party/flatbuffers index d0cede9c90c5..697147a2e686 160000 --- a/third_party/flatbuffers +++ b/third_party/flatbuffers @@ -1 +1 @@ -Subproject commit d0cede9c90c5257537c293517a21376408b549fa +Subproject commit 697147a2e686486424b9d15fc3e1612586a60f97 diff --git a/third_party/ideep b/third_party/ideep index 4a56ab2c3f61..82aac435b5ec 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 4a56ab2c3f61c44e0f8ea241beeb732b7d70dc5b +Subproject commit 82aac435b5ecfec0855d0d72b84aee3ed0e72813 diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index eb7026604041..589d22359f68 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -58,6 +58,7 @@ libtorch_nvfuser_runtime_sources = [ "torch/csrc/jit/codegen/cuda/runtime/helpers.cu", "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu", "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu", + "torch/csrc/jit/codegen/cuda/runtime/memory.cu", "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu", "torch/csrc/jit/codegen/cuda/runtime/tensor.cu", "torch/csrc/jit/codegen/cuda/runtime/tuple.cu", diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 2287b2835ee6..f58b36de8daf 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -458,6 +458,34 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT(false, "Unreachable"); } + 
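// The helpers added below generate the new vectorized memory ops. A sketch
// of the intended output (illustrative only; the tensor names T0/T1/T2 and
// the Array<__half,8> vector type are assumptions, not taken verbatim from
// the patch):
//
//   Ampere::cpAsync(reinterpret_cast<Array<__half,8>*>(&T1[i]),   // smem dst
//                   reinterpret_cast<Array<__half,8>*>(&T0[i]));  // gmem src
//   Turing::ldMatrix(*reinterpret_cast<Array<__half,8>*>(&T2[i]), // reg dst
//                    &T1[i]);                                     // smem src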
std::string genVectorPointer(Val* val, DataType dtype, int vec_size) { + std::stringstream ss; + + ss << "reinterpret_cast*>(&" + << gen(val) << ")"; + + return ss.str(); + } + + void genCpAsync(const UnaryOp* uop, int vec_size) { + auto dtype = uop->in()->getDataType().value(); + + indent() << "Ampere::cpAsync(" + << genVectorPointer(uop->out(), dtype, vec_size) << "," + << genVectorPointer(uop->in(), dtype, vec_size) << ");\n"; + } + + void genLdMatrix(const UnaryOp* uop, int vector_word_size) { + auto dtype = uop->in()->getDataType().value(); + indent() << "Turing::ldMatrix"; + if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { + code_ << "T"; + } + code_ << " ("; + code_ << "*" << genVectorPointer(uop->out(), dtype, vector_word_size) << "," + << "&" << gen(uop->in()) << ");\n"; + } + void handle(const UnaryOp* uop) final { bool is_vector_op = false; size_t vector_word_size = 1; @@ -505,7 +533,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (vectorize_op) { TORCH_INTERNAL_ASSERT( - uop->getUnaryOpType() == UnaryOpType::Set, + uop->getUnaryOpType() == UnaryOpType::Set || + uop->getUnaryOpType() == UnaryOpType::CP_ASYNC || + uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT, "Cannot vectorize operations that are not sets. ", "Use cache_before and cache_after to store/load with vectorized reads into buffers."); is_vector_op = true; @@ -522,6 +553,20 @@ class CudaKernelGenerator : private OptOutConstDispatch { } if (is_vector_op) { + // Note: Non-vectorized cp async isn't yet supported. + // will support in a follow up. + if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + genCpAsync(uop, vector_word_size); + return; + } + + // TODO: do we want to define a unary op called memory/copy/ldst? + if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { + genLdMatrix(uop, vector_word_size); + return; + } + auto out_tv = uop->out()->as()->view(); if (uop->in()->isScalar()) { // Note: @@ -896,7 +941,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (init) { ss << "init"; } - ss << toString(options.macro) << toString(options.operand_layout); + ss << toString(options.macro); + + if (isVolta(options.macro)) { + ss << toString(options.operand_layout); + } else if (isTuring(options.macro) || isAmpere(options.macro)) { + // mma's in turing and ampere TN only. 
+ ss << "TN"; + } // TODO: additional parameter could be removed by swizzling iterdomain auto acc_stride = mma->accStride(); TORCH_INTERNAL_ASSERT(acc_stride > 0); @@ -1804,7 +1856,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } - void handle(const kir::BlockSync*) final { + void handle(const kir::BlockSync* sync) final { + if (sync->syncGmem()) { + indent() << "Ampere::cpAsyncBarrier();\n"; + } // Use a custom synchronization method if enabled if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { indent() << "block_sync::sync();\n"; diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index d81ce7b2c55c..cf7319a1e74b 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -98,6 +99,7 @@ std::string kernelPreamble() { ss << nvfuser_resources::welford_cu; ss << nvfuser_resources::warp_cu; ss << nvfuser_resources::tensorcore_cu; + ss << nvfuser_resources::memory_cu; ss << nvfuser_resources::fused_reduction_cu; // Random utilities diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index d3b7dcb33764..9e93870bcaed 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -404,14 +404,20 @@ class TORCH_CUDA_CU_API TensorView : public Val { TensorView* var, TensorView* n); - // Create a TensorView before the original tensor. A common use case is to - // write results into shared memory or registers before moving to global - // memory. Analogous to TVM Cache_Write - TensorView* cache_before(); - - // Create a TensorView after the original tensor. A common use case is to - // read tensor into shared memory or registers. Analogous to TVM Cache_Read - TensorView* cache_after(); + //! Create a TensorView before the original tensor. A common use case is to + //! write results into shared memory or registers before moving to global + //! memory. Analogous to TVM Cache_Write + //! + //! @param cache_op: memory operator to use for the inserted op between + //! the the data tensor and the cache tensor + TensorView* cache_before(UnaryOpType cache_op = UnaryOpType::Set); + + //! Create a TensorView after the original tensor. A common use case is to + //! read tensor into shared memory or registers. Analogous to TVM Cache_Read + //! + //! @param cache_op: memory operator to use for the inserted op between + //! the the data tensor and the cache tensor + TensorView* cache_after(UnaryOpType cache_op = UnaryOpType::Set); // For a fusion output with other uses, we want to avoid writing to global // memory and then reading the output again. 
We write to global memory diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 44f2e29df5e9..e9ab0cdfe4a0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1096,8 +1096,9 @@ void IterDomain::parallelize(ParallelType t) { if (isMmaSwizzled()) { TORCH_CHECK( - t == ParallelType::Vectorize, - "Parallel type other than vectorize not allowed for warp mapped ids"); + t == ParallelType::Vectorize || t == ParallelType::TIDx || + t == ParallelType::Serial, + "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids"); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 46fdc78aade7..92678c8aae0f 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -78,8 +78,10 @@ TensorIndex::TensorIndex( } } -BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) - : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) { +BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync, bool gmem_sync) + : Expr(passkey, ExprType::BlockSync), + war_sync_(war_sync), + gmem_sync_(gmem_sync) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index bc714e5d87e4..e52cf18e3de0 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -244,15 +244,35 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API BlockSync final : public Expr { public: - explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); + explicit BlockSync( + IrBuilderPasskey passkey, + bool war_sync = false, + bool gmem_sync = false); bool isWarHazardSync() const { return war_sync_; } + bool syncGmem() const { + return gmem_sync_; + } + private: // TODO: war_sync_ is only used for testing/validation purposes. bool war_sync_ = false; + + //! Indicates if this block sync also synchronizes with asynchronous + //! copy ops from global mem to shared mem. + //! Currently making this a parameter in BlockSync so that we always + //! sync the blocks after synchronizing with the async gmem loads to + //! avoid having to model more complex sync patterns. + //! For most cases it could be assumed that output of async gmem op + //! in the shared memory will be accessed by other threads as well + //! since otherwise wouldn't need to load to shared mem. + //! An exception might be if we use shared mem as extension of + //! register space and we could model this sync as a separate op + //! if we see a use case. 
+ bool gmem_sync_ = false; }; // Synchronize all blocks in device, implies cooperative group launch is diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 571ba62a545b..5c739f625c44 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -68,7 +68,10 @@ void validateDoubleBufferedTensor(const TensorView* tv) { auto def = tv->definition(); TORCH_INTERNAL_ASSERT( def->isA() && - def->as()->getUnaryOpType() == UnaryOpType::Set, + def->as()->getUnaryOpType() == UnaryOpType::Set || + def->as()->getUnaryOpType() == UnaryOpType::CP_ASYNC || + def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIX || + def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIXT, "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set is supported: ", def->toString()); @@ -122,6 +125,14 @@ class DoubleBufferFusionInspector : private IterVisitor { return; } + TORCH_INTERNAL_ASSERT( + tv->definition(), "Fusion input shouldn't be double buffered.", tv); + if (auto uop = dynamic_cast(tv->definition())) { + if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + db_info_.setAsyncGmemDoubleBuffer(); + } + } + validateDoubleBufferedTensor(tv); auto db_axis = getDoubleBufferAxis(tv); @@ -407,7 +418,9 @@ class DoubleBufferInserter : private kir::ExprMutator { // RAW sync is not inserted for double buffered tensors. The only // exception is the prologue load. if (write_to_smem) { - auto sync = IrBuilder::create(); + auto sync = IrBuilder::create( + false, + GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer()); registerInsertBefore(double_buffer_loop, sync); } diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 96bc247f4ff6..731d20eca2b2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -128,12 +128,21 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { Val* getOriginalAllocSize(const TensorView* tv); + void setAsyncGmemDoubleBuffer() { + has_async_gmem_double_buffer_ = true; + } + + bool hasAsyncGmemDoubleBuffer() { + return has_async_gmem_double_buffer_; + } + private: TvInfo& getTvInfo(const TensorView* tv); private: //! Keeps track of information for lowering double buffered tensors std::unordered_map map_; + bool has_async_gmem_double_buffer_ = false; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 1acf33150cc4..74bb4c9d9e0c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -134,6 +134,8 @@ class WarSyncInserter : private kir::ExprMutator { for (const auto& entry : lower_alloc_info_map) { alloc_map_.insert(entry.first); } + need_to_sync_gmem_ = + GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer(); kir::ExprMutator::traverseAndInsert(exprs); } @@ -275,7 +277,8 @@ class WarSyncInserter : private kir::ExprMutator { // WAR Sync is necessary in this loop, register its insertion. 
if (insert_sync) { - auto sync_expr = IrBuilder::create(true); + auto sync_expr = + IrBuilder::create(true, need_to_sync_gmem_); kir::ExprMutator::registerInsertAfter( for_loop->body().exprs().back(), sync_expr, &for_loop->body()); handle(sync_expr); @@ -332,6 +335,8 @@ class WarSyncInserter : private kir::ExprMutator { // alias tv, each aliased tv in a unique ca_loop has to be tracked separately // for WAR insertion. std::unordered_map> smem_allocations_; + + bool need_to_sync_gmem_ = false; }; class ExprFlattener : private kir::IrVisitor { @@ -519,7 +524,10 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { sync_expr = IrBuilder::create( sync_bitmap, maybe_alloc->buffer()); } else { - sync_expr = IrBuilder::create(); + bool need_gmem_sync = gmem_sync_after_.count(expr); + auto sync_expr = IrBuilder::create( + false, // is not war sync + need_gmem_sync); } if (out_tv->getComputeAtPosition() == 0) { // Sync should be placed at global scope, after its outer most loop if @@ -661,6 +669,17 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { } auto last_smem_writes = isModifiedSharedMemory(smem, expr->inputs()); + + bool need_gmem_sync = std::any_of( + last_smem_writes.begin(), last_smem_writes.end(), [](Expr* expr) { + if (auto uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + return true; + } + } + return false; + }); + if (!last_smem_writes.empty()) { TORCH_INTERNAL_ASSERT( prev_tv_expr != nullptr, @@ -671,6 +690,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { bitmap.set(ParallelType::TIDz); sync_after_.emplace_back(std::make_pair(prev_tv_expr, bitmap)); last_writes_.push_back(last_smem_writes); + if (need_gmem_sync) { + gmem_sync_after_.insert(prev_tv_expr); + } smem.clear(); } @@ -710,6 +732,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { //! it is not placed before those write expressions. std::deque> last_writes_; + //! Keep track of expressions that must be followed by a gmem sync + std::unordered_set gmem_sync_after_; + public: static std::vector insert(const std::vector& loop_nests) { ReadAfterWriteSyncs inserter(loop_nests); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 166f38d6cf56..6a1224573146 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -50,9 +50,53 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { setWritePredicate(expr, conditional); } + if (isPredicatedGmemToSharedInit(expr)) { + invertPredicateForGmemToSharedMemInitialize(expr); + } + kir::IrVisitor::handle(expr); } + void invertPredicateForGmemToSharedMemInitialize(Expr* expr) { + auto pred = expr->predicate()->value(); + auto invert = SimplifyingIrBuilder::notExpr(pred); + expr->predicate()->setValue(invert->as()); + } + + bool isPredicatedGmemToSharedInit(Expr* expr) { + if (!expr->predicate() || + expr->predicate()->predicate_type() != PredicateType::Vectorize) { + return false; + } + if (auto ite = dynamic_cast(expr)) { + if (!ite->elseBody().empty()) { + return false; + } + if (ite->predicate()->predicate_type() != PredicateType::Vectorize) { + return false; + } + for (auto inner_expr : ite->thenBody().exprs()) { + // Checking if this expr writes scalar to shared mem tensor, + // while fuser expr read from gmem to shared mem. 
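// Illustrative sketch (not from the patch): the pattern being matched here is
// the zero-fill that pairs with a cp.async gmem->smem copy. Once the Vectorize
// predicate on it is inverted above, the lowered code looks roughly like
//   if (!vec_pred) { T_smem[...] = 0; }                        // out-of-bound fill
//   if ( vec_pred) { Ampere::cpAsync(/*smem tile*/, /*gmem tile*/); }
// so downstream shared-memory readers never see uninitialized data.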
+ if (auto kir_uop = dynamic_cast(inner_expr)) { + if (auto out_tv = dynamic_cast(kir_uop->out())) { + auto fuser_expr = out_tv->view()->definition(); + TORCH_INTERNAL_ASSERT( + fuser_expr != nullptr, + "kir tv expr without corresponding fuser expr defining the output"); + if (auto in_fuser_tv = + dynamic_cast(fuser_expr->input(0))) { + return in_fuser_tv->getMemoryType() == MemoryType::Global && + out_tv->view()->getMemoryType() == MemoryType::Shared && + kir_uop->in()->isConstScalar(); + } + } + } + } + } + return false; + } + void setWritePredicate(Expr* expr, Bool* read_cond) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); @@ -127,6 +171,13 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { }; void assertOnWarpOps(const Expr* expr) { + if (auto uop = dynamic_cast(expr)) { + TORCH_INTERNAL_ASSERT( + uop->getUnaryOpType() != UnaryOpType::LD_MATRIX || + uop->getUnaryOpType() != UnaryOpType::LD_MATRIXT, + "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims", + expr); + } TORCH_INTERNAL_ASSERT( !expr->isA(), "Mma op: cannot eliminate predicate for mma op, tiling not valid"); @@ -159,8 +210,61 @@ class PredicateAnalyzer : public OptOutDispatch { // larger than the axis size, it's not safe to skip predication // Check that parallel dimension will not generate out of bound index - if (!(producer->getMemoryType() == MemoryType::Local && - consumer->getMemoryType() == MemoryType::Local)) { + if (producer->getMemoryType() == MemoryType::Global || + consumer->getMemoryType() == MemoryType::Global) { + return true; + } + + bool needs_sharedmem_addr_pred = false; + if (producer->getMemoryType() == MemoryType::Shared || + consumer->getMemoryType() == MemoryType::Shared) { + // Indexing is based on consumer leaf ids so check the consumer. + auto& parallel_dimension_map = + GpuLower::current()->parallelDimensionMap(); + for (auto id : consumer->domain()->domain()) { + if (id->isThreadDim()) { + auto ptype = id->getParallelType(); + if (!parallel_dimension_map.isExact(ptype)) { + needs_sharedmem_addr_pred = true; + } + } + + // TODO: (Address in a follow up) + // predicate removal with init breaks unroll and unswitch, eg. as in + // issue 1133 disabling this usage for now + if (id->getParallelType() == ParallelType::Unroll || + id->getParallelType() == ParallelType::Unswitch) { + needs_sharedmem_addr_pred = true; + } + + // TODO: (Address in a follow up) + // This cannot yet be removed since smem initialization needs to be + // handled specially, e.g. as in smem_reduce test. Will be able to + // lift this one once the generic pred removal pass is ready. + auto consumer_def = consumer->definition(); + if (consumer_def->isA() || + consumer_def->isA()) { + if (producer->getMemoryType() == MemoryType::Shared) { + needs_sharedmem_addr_pred = true; + } + } + } + } + + // Restricted usage for cp async. + // TODO: + // cp async initialization and data load path are async + // so would need support for multiple predicate to handle + // generic support. Currently only support the use cases where + // the smem address access will never be out of bound. 
+ auto consumer_def = dynamic_cast(consumer->definition()); + if (consumer_def && + consumer_def->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + TORCH_INTERNAL_ASSERT( + !needs_sharedmem_addr_pred, "unsupported usage for cp async"); + } + + if (needs_sharedmem_addr_pred) { return true; } @@ -307,6 +411,16 @@ bool PredicateElimination::needsPredicate(Expr* expr) const { }); }); + if (auto uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT || + uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + TORCH_INTERNAL_ASSERT( + !filters.back()(expr), + "No support yet for shared mem primitives with shift"); + } + } + // Predicates the expression if any producer-consumer pair of the // expression needs to be predicated filters.emplace_back([](Expr* expr) { diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index c4f926131a8a..ff9fc411db6b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -109,7 +109,63 @@ void UnrollPass::handle(Expr* expr) { // Vectorized expressions should never use inline predicates kir::Predicate* pred = nullptr; - if (!unswitched_loop_ && + + auto is_ld_matrix_op = [](Expr* expr) { + if (auto uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { + return true; + } + } + return false; + }; + + bool is_ld_matrix = false; + bool is_ld_matrix_producer = false; + if (is_ld_matrix_op(expr)) { + // Note & TODO: currently not support predicated ldmatrix since it'd + // need to make sure all threads in a warp evaluate the predicate + // to the same value. + // Currently asserting that the predicate for ldmatrix and its producer + // can be omitted. Would need to build out support for warp level + // predicates if we want vectorization predicates to be active on + // these ops. + TORCH_INTERNAL_ASSERT( + GpuLower::current()->predicateElimination().canOmitPredicate(expr), + "Vectorize: unsupported predicate for warp ops"); + is_ld_matrix = true; + } + + // Note & TODO: This part is a WAR that disables any predicated + // producer of ldmatrix consumers. Ldmatrix currently cannot + // be predicated due to no support in warp-uniform predicates. + // A predicated producer will result in ldmatrix propagating + // data outside of the producer's valid region. + // Except for vectorized global loads which should be properly + // initialized outside of the predicated region. 
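// For example (sketch, not from the patch): if a non-global copy such as
// T_smem = Set(T_other) is predicated and T_smem then feeds an ldmatrix,
// lanes where the predicate was false would leave stale shared memory that
// the warp-wide ldmatrix still reads. Such producers must therefore have
// their predicates eliminated; vectorized gmem loads are exempted below
// because their buffers are zero-initialized outside the predicate
// (see the inverted Vectorize predicate in lower_predicate.cpp).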
+ if (!std::any_of( + expr->inputs().begin(), expr->inputs().end(), [](Val* val) { + return val->isA() && + val->as()->getMemoryType() == MemoryType::Global; + })) { + for (auto out_tv : ir_utils::filterByType(expr->outputs())) { + for (auto use : expr->fusion()->unordered_uses(out_tv)) { + if (is_ld_matrix_op(use)) { + TORCH_INTERNAL_ASSERT( + GpuLower::current()->predicateElimination().canOmitPredicate( + expr), + "Vectorize: producers of ldmatrix cannot be predicated except for global loads"); + is_ld_matrix_producer = true; + break; + } + } + } + } + + const bool omit_vector_predicate_for_ldmatrix = + is_ld_matrix || is_ld_matrix_producer; + + if (!unswitched_loop_ && !omit_vector_predicate_for_ldmatrix && std::any_of( for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) { return fl->iter_domain()->getParallelType() == diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index df828b96dce0..865ef5d47e1e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -592,11 +592,17 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { } } if (has_vectorize_dim) { + auto uop_type = UnaryOpType::Abs; + if (tv->definition() && tv->definition()->isA()) { + uop_type = tv->definition()->as()->getUnaryOpType(); + } TORCH_INTERNAL_ASSERT( tv->definition() == nullptr || (tv->definition()->isA() && - tv->definition()->as()->getUnaryOpType() == - UnaryOpType::Set), + (uop_type == UnaryOpType::Set || + uop_type == UnaryOpType::CP_ASYNC || + uop_type == UnaryOpType::LD_MATRIX || + uop_type == UnaryOpType::LD_MATRIXT)), "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.", "TensorView: ", tv); @@ -860,6 +866,79 @@ void validateMmaTensors(MmaOp* mma) { validate_operand_ids(mma->inB()->as()); } +//! Note and TODO: +//! Currently relying on ldmatrix to +//! obtain the correct data layout for turing/ampere +//! mma's. +//! This restriction will eventually not +//! be necessary once the scatter swizzle is ready. +void validateTuringInput(TensorView* tv) { + // Currently only allowed input paths are: + // ldmatrix -> mma or + // ldmatrix -> broadcast -> mma + // We actually wouldn't want too much flexibility here since + // this path is very perf critical. But the check itself + // can be made cleaner once we have the correct swizzle + // labeling. + // The most generic support would involve build out to + // support any pointwise ops that does not change the + // datalayout. + auto tv_def = tv->definition(); + TORCH_INTERNAL_ASSERT(tv_def); + if (tv_def->isA()) { + tv_def = tv_def->input(0)->definition(); + } + TORCH_INTERNAL_ASSERT(tv_def); + auto tv_uop = dynamic_cast(tv_def); + TORCH_INTERNAL_ASSERT(tv_uop); + TORCH_INTERNAL_ASSERT( + tv_uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + tv_uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT); +} + +// Output of ldmatrix is swizzled with the mma format, so it +// currently should not be fused with any pointwise ops. This +// check is to protect against these cases. +// This would also not be needed once scatter swizzle ready, should +// just become a swizzle format check if we wanted to fuse ldmatrix +// with any op other than mma. 
+void validateLdMatrixOutput(TensorView* tv) { + const auto& out_uses = tv->fusion()->unordered_uses(tv); + if (out_uses.empty()) { + return; + } + // Could be relaxed + TORCH_INTERNAL_ASSERT(out_uses.size() == 1); + auto out_use = *(out_uses.begin()); + + if (out_use->isA()) { + validateLdMatrixOutput(out_use->output(0)->as()); + return; + } + + TORCH_INTERNAL_ASSERT( + out_use->isA(), + "validateLdMatrixOutput: currently only supports single mma use for ldmatrix", + out_use); +} + +// Checks that the memory ops are supported on the targeted GPU +void validateArchMemoryOp(UnaryOp* uop) { + auto uop_type = uop->getUnaryOpType(); + switch (uop_type) { + case UnaryOpType::LD_MATRIX: + case UnaryOpType::LD_MATRIXT: + validateMinimumArch(7, 5); + validateLdMatrixOutput(uop->out()->as()); + return; + case UnaryOpType::CP_ASYNC: + validateMinimumArch(8, 0); + return; + default: + return; + } +} + } // namespace //! Validate data format and GPU arch compatibility of scheduled @@ -875,11 +954,21 @@ void validateMma(Fusion* fusion) { case MmaOptions::MacroType::Volta_16_16_4: validateMinimumArch(7, 0); break; + case MmaOptions::MacroType::Turing_16_8_16: + validateMinimumArch(7, 5); + + // Check that operands come from ldmatrix, can be + // relaxed once swizzles can be labeled on iterdomains. + validateTuringInput(mma->inA()->as()); + validateTuringInput(mma->inB()->as()); default: TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro"); break; } } + if (auto uop = dynamic_cast(expr)) { + validateArchMemoryOp(uop); + } } } diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 3751cdea6bcf..9fb18cfa6f88 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -16,6 +16,9 @@ MmaBuilder::MmaBuilder( case MmaOptions::MacroType::Volta_16_16_4: option_.accumulator_stride = outer_stride * 4; break; + case MmaOptions::MacroType::Turing_16_8_16: + option_.accumulator_stride = outer_stride * 2; + break; default: TORCH_CHECK(false, "unsupported macro"); break; @@ -37,6 +40,33 @@ MmaOptions MmaBuilder::build() const { return option_; } +namespace { + +// Utility to get ldmatrix direction a mma layout and operand +UnaryOpType getLdMatrixType(MmaOptions options) { + bool transpose = false; + switch (options.macro) { + case MmaOptions::MacroType::Turing_16_8_16: + // Turing mma assumes TN as default + transpose = (options.operand == MmaOptions::Operand::A && + !isOperandTransposed(options)) || + (options.operand == MmaOptions::Operand::B && + isOperandTransposed(options)); + break; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported op with ldmatrix"); + break; + } + + return transpose ? 
UnaryOpType::LD_MATRIXT : UnaryOpType::LD_MATRIX; +} + +} // namespace + +UnaryOpType MmaBuilder::ldMatrix() { + return getLdMatrixType(option_); +} + bool isVolta(MmaOptions::MacroType macro) { return macro == MmaOptions::MacroType::Volta_16_16_4; } @@ -54,6 +84,9 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 8; break; + case MmaOptions::MacroType::Turing_16_8_16: + return 4; + break; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); break; @@ -66,6 +99,9 @@ int getInputARegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; + case MmaOptions::MacroType::Turing_16_8_16: + return 8; + break; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); break; @@ -78,6 +114,8 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; + case MmaOptions::MacroType::Turing_16_8_16: + return 4; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); break; @@ -126,6 +164,9 @@ std::string toString(MmaOptions::MacroType mt) { case MmaOptions::MacroType::Volta_16_16_4: ss << "M16N16K4"; break; + case MmaOptions::MacroType::Turing_16_8_16: + ss << "M16N8K16"; + break; default: TORCH_INTERNAL_ASSERT(false, "undefined mma type"); break; diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 5f42d41ded65..40d8aab410a6 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -57,7 +57,7 @@ struct MmaOptions { enum class MacroType { NoMMA = 0, Volta_16_16_4, - Turing_16_8_16, // place holder for turing/ampere mma + Turing_16_8_16, Ampere_16_8_8 // place holder for tf32 }; @@ -102,6 +102,7 @@ class TORCH_CUDA_CU_API MmaBuilder { MmaBuilder& layout(MmaOptions::MmaInputLayout layout); MmaBuilder& operand(MmaOptions::Operand a_or_b); MmaOptions build() const; + UnaryOpType ldMatrix(); private: MmaOptions option_; diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu new file mode 100644 index 000000000000..a4a9e6391e3a --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -0,0 +1,105 @@ +// Utility macro for this file +#define DEVICE_INLINE __device__ inline + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + +namespace Turing { + +namespace util { + +// Special utility for ldmatrix and cp_async +DEVICE_INLINE unsigned toSmem(const void* ptr) { + unsigned smem_ptr; + + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + return smem_ptr; +} + +} // namespace util + +DEVICE_INLINE void ldMatrix(Array<__half, 4>& out, void const* ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(ptr); + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr)); +} + +DEVICE_INLINE void ldMatrixT(Array<__half, 4>& out, void const* ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(ptr); + asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr)); +} + +DEVICE_INLINE void ldMatrix(Array<__half, 8>& out, void const* ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(ptr); + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr)); +} + 
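// Note (sketch, not from the patch): the Array<__half, 4> overloads map to
// ldmatrix.x2 (two 8x8 b16 tiles, 4 halfs landing in each thread's fragment)
// and the Array<__half, 8> overloads to ldmatrix.x4 (four 8x8 tiles, 8 halfs
// per thread). In both cases a subset of lanes (16 for x2, all 32 for x4)
// supplies the shared-memory address of one 8x8 row, converted to the
// shared-state-space address expected by the instruction via util::toSmem.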
+DEVICE_INLINE void ldMatrixT(Array<__half, 8>& out, void const* ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr)); +} + +} // namespace Turing + +#endif // Arch 75 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + +namespace Ampere { + +// MMA instruction wrappers (sm_80+): + +namespace util { + +// Special utility for cp_async +DEVICE_INLINE unsigned toSmem(void* ptr) { + unsigned smem_ptr; + + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + return smem_ptr; +} + +} // namespace util + +template +DEVICE_INLINE void cpAsync(Array* smem_ptr, void const* gmem_ptr) { + unsigned smem_addr = util::toSmem(&(smem_ptr->val[0])); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr), + "l"(gmem_ptr), + "n"(byte_size)); +} + +// TODO: Might have a different category of sync if we want to build out this: +DEVICE_INLINE void cpAsyncBarrier() { + asm volatile("cp.async.wait_all;"); +} + +} // namespace Ampere + +#endif // Arch 80 + +#undef DEVICE_INLINE \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu index f95978e84475..7d240f24a1dc 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -18,8 +18,8 @@ namespace Volta { namespace util { // MMA instruction wrappers (sm_70+): // The instruction wrappers below are quarter-warp macros, which currently -// nvfuser -// doesn't explicitly model. So they are currently only meant to be +// nvfuser doesn't explicitly model. 
+// So they are currently only meant to be // used as building blocks in warp level mma macros // 8x8x4 mma instruction, per quarter warp (8 threads), fp32 accumulate @@ -212,4 +212,78 @@ DEVICE_INLINE void initM16N16K4NT(Array* accumulator) { } // namespace Volta +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + +namespace Turing { + +namespace util { +// MMA instruction wrappers (sm_75+): +DEVICE_INLINE void m16n8k16TN( + Array* C, + Array<__half, 8>* A, + Array<__half, 4>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + const unsigned* _D = reinterpret_cast(C); + + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_A[2]), + "r"(_A[3]), + "r"(_B[0]), + "r"(_B[1]), + "r"(_D[0]), + "r"(_D[1]), + "r"(_D[2]), + "r"(_D[3])); +} + +} // namespace util + +template +DEVICE_INLINE void M16N8K16TN( + Array* C, + Array<__half, 8>* A, + Array<__half, 4>* B) { + // TODO: in a follow up, + // lift this fused swizzle onto iterdomain + float* _C = reinterpret_cast(C); + + float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]}; + + util::m16n8k16TN(reinterpret_cast*>(&C_data[0]), A, B); + + _C[0] = C_data[0]; + _C[1] = C_data[1]; + _C[acc_stride] = C_data[2]; + _C[acc_stride + 1] = C_data[3]; +} + +// MMA instruction wrappers (sm_75+): +DEVICE_INLINE void M16N8K8TN( + Array* C, + Array<__half, 4>* A, + Array<__half, 2>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n" + : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_B[0]), + "r"(_C[0]), + "r"(_C[1]), + "r"(_C[2]), + "r"(_C[3])); +} + +} // namespace Turing + +#endif // Arch 75 + #undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 875a9ea5ab15..6dfe7e0b0944 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -16,6 +16,31 @@ namespace { // Utility for mma dimension matching enum class MmaDimension { M = 0, N, K }; +// Utility for mma dimension matching before broadcast, +// assumes the innermost 2 dimensions are the mma +// operand dimensions, i.e. mnk. +IterDomain* getMmaOperandRootDimension2d( + TensorView* tv, + MmaOptions options, + MmaDimension mma_dimension) { + TORCH_INTERNAL_ASSERT(tv->getMaybeRFactorDomain().size() >= 2); + // NT : K,M x K,N -> K,M,N + // TT : M,K X K,N -> M,K,N + // TN : M,K X N,K -> M,N,K + int axis_id = mma_dimension == MmaDimension::K ? 1 : 0; + bool is_transposed = isOperandTransposed(options); + + // Decode the transpostion + if ((options.operand == MmaOptions::Operand::A && !is_transposed) || + (options.operand == MmaOptions::Operand::B && is_transposed)) { + axis_id = 1 - axis_id; + } + + int root_size = tv->getMaybeRFactorDomain().size(); + // Convert to index from right. + return tv->getMaybeRFactorDomain().at(root_size + axis_id - 2); +} + // Utility for mma dimension matching, assumes the innermost // 3 dimensions are the mma operand dimensions, i.e. mnk, but // not necessarily in this order. 
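// Worked example for the 2d helper above (sketch, TN layout): operand A has
// root domain [M, K] and operand B has root domain [N, K], so
// MmaDimension::K resolves to the innermost root id of either operand, while
// M (for A) and N (for B) resolve to the second innermost root id.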
@@ -63,6 +88,23 @@ IterDomain* getMmaOperandRootDimension( if (isVolta(options.macro)) { return getMmaOperandRootDimension3d( tv, options.operand_layout, mma_dimension); + } else if (isTuring(options.macro)) { + // Volta mma swizzle requires the broadcast dimension to + // participate, which is not true in Turing+. So the two + // cases w/ or w/o the broadcast are supported here for + // mma pre-swizzle validation. + bool has_broadcast_or_reduction = std::any_of( + tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end(), + [](IterDomain* id) { return id->isBroadcast() || id->isReduction(); }); + if (has_broadcast_or_reduction) { + TORCH_INTERNAL_ASSERT(tv->nDims() >= 3); + return getMmaOperandRootDimension3d( + tv, options.operand_layout, mma_dimension); + } else { + TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); + return getMmaOperandRootDimension2d(tv, options, mma_dimension); + } } TORCH_INTERNAL_ASSERT(false, "unreachable"); return nullptr; @@ -178,6 +220,12 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput( setWarpMapped(tv, 5); } break; + case MmaOptions::MacroType::Turing_16_8_16: + scheduleTuringM16N8K16MmaWarpOutput(tv, options); + if (tv->definition()->isA()) { + setWarpMapped(tv, 4); + } + break; default: TORCH_CHECK( false, "scheduleMmaWarp: unsupported mma option ", toString(macro)); @@ -193,6 +241,9 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) { case MmaOptions::MacroType::Volta_16_16_4: scheduleVoltaOperandRead(tv, options); break; + case MmaOptions::MacroType::Turing_16_8_16: + scheduleTuringOperandRead(tv, options); + break; default: TORCH_CHECK(false, "WarpMmaSwizzler: please specify macro"); break; @@ -233,6 +284,40 @@ void validateResultInnerMN(TensorView* tv, int m, int n) { tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), n)); } +//! Performs checks on tv given to schedule ld matrix. +//! Currently only allowed ones are either: +//! 1. direct output of an ldmatrix op or +//! 2. direct output of a broadcast op following a ldmatrix op +//! 
Returns true if the tv is an immediate output of ldmatrix op +bool checkLdMatrixTv(TensorView* tv) { + // First check if tv is an ldmatrix output: + auto tv_def = tv->definition(); + TORCH_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); + auto tv_def_uop = dynamic_cast(tv_def); + bool is_immediate_output = true; + if (tv_def_uop == nullptr) { + // Only allow one broadcast in between tv and the ldmatrix op + TORCH_CHECK(tv_def->isA()); + tv_def = tv_def->input(0)->definition(); + TORCH_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); + tv_def_uop = dynamic_cast(tv_def); + is_immediate_output = false; + } + + TORCH_CHECK(tv_def_uop != nullptr, "ldmatrix : invalid op"); + TORCH_CHECK( + tv_def_uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || + tv_def_uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT, + "ldmatrix : invalid op type"); + TORCH_CHECK(tv->nDims() > 2); + TORCH_CHECK(!tv->axis(-1)->isBroadcast()); + TORCH_CHECK(!tv->axis(-1)->isReduction()); + TORCH_CHECK(!tv->axis(-2)->isBroadcast()); + TORCH_CHECK(!tv->axis(-2)->isReduction()); + + return is_immediate_output; +} + void scheduleVoltaA(TensorView* tv, MmaOptions options) { // Assumed: // [..., 16, 16 ,4] @@ -303,6 +388,101 @@ void scheduleVoltaB(TensorView* tv, MmaOptions options) { tv->axis(-2)->parallelize(ParallelType::TIDx); } +void scheduleLdMatrix(TensorView* tv, MmaOptions options) { + // Check if tv should use ldmatrix layout and + // if tv is immediate output of ldmatrix + bool is_immediate_output = checkLdMatrixTv(tv); + + // Decode transposition requirement for turing mma + bool transposed = options.operand == MmaOptions::Operand::A + ? !isOperandTransposed(options) + : isOperandTransposed(options); + // Check mma option is supported + TORCH_CHECK( + options.macro == MmaOptions::MacroType::Turing_16_8_16, + "scheduleLdMatrix: unknown macro for ldmatrix"); + + if (options.operand == MmaOptions::Operand::A) { + TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); + // validation: + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::M), + tv->axis(-2), + 16)); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + 16)); + + //[16m, 16k] + tv->split(-2, 8); + tv->split(-1, 8); + + // -4 -3 -2 -1 + //[2o, 8o, 2i, 8i] + tv->reorder({{-4, -3}, {-3, -2}, {-2, -4}}); + + // -4 -3 -2 -1 + // [2i, 2o, 8o, 8i] + + if (transposed) { + tv->reorder({{-1, -2}, {-2, -1}}); + } + + tv->merge(-4); + tv->merge(-3); + // [warp, 8i/o] + + tv->axis(-2)->parallelize(ParallelType::TIDx); + } else if (options.operand == MmaOptions::Operand::B) { + // validation: + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::N), + tv->axis(-2), + 8)); + TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + 16)); + + if (transposed) { + // [8, 16] + tv->split(-2, 4); + + // [2i, 4i, 16] + tv->reorder({{-1, -2}, {-2, -1}}); + // [2i, 16, 4i] + + tv->merge(-3); + // [warp, 4i] + } else { + //[8, 16] + tv->split(-1, 4); + tv->split(-2, 2); + + // 0 1 2 3 4 + //[8, oo2,oi2,i4] + tv->reorder({{-4, -2}, {-2, -4}}); + + // 0 1 2 3 + //[oi2, oo2, 8,i4] + + tv->merge(-4); + tv->merge(-3); + // 0 1 + //[warp, i4] + } + + tv->axis(-2)->parallelize(ParallelType::TIDx); + } else { + TORCH_INTERNAL_ASSERT(false, "unreachable"); + } + + if (is_immediate_output) { + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } +} + } // namespace void 
WarpMmaSwizzler::scheduleVoltaOperandRead( @@ -380,6 +560,52 @@ void WarpMmaSwizzler::scheduleVoltaM16N16K4Fp32Output( } } +void WarpMmaSwizzler::scheduleTuringOperandRead( + TensorView* tv, + MmaOptions options) { + scheduleLdMatrix(tv, options); + setWarpMapped(tv, 2); +} + +void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput( + TensorView* tv, + const MmaOptions& options) { + // Assume last 2 dims [M16, N8] or [M16, N8, R] + // Locate instruction m + bool is_reduction = tv->axis(-1)->isReduction(); + + // Make sure instruction tile size is correct. + if (is_reduction) { + validateInnerMNK(tv, options, 16, 8, 16); + } else { + validateResultInnerMN(tv, 16, 8); + } + + int m_pos = is_reduction ? -3 : -2; + + // m + // [16, 8 (,R)] + tv->split(m_pos, 8); + tv->split(m_pos + 1, 2); + + // m + // [2o, 8o, 4i, 2i (,R)] + tv->merge(m_pos - 1); + + // m + // [2o, Warp, 2i (,R)] + TORCH_CHECK(tv->definition() != nullptr); + + if (is_reduction && tv->definition()->isA()) { + // Set instruction loops for mma reduce + for (int pos : c10::irange(4)) { + tv->axis(-pos - 1)->parallelize(ParallelType::Mma); + } + } + + tv->axis(m_pos)->parallelize(ParallelType::TIDx); +} + namespace { bool isMmaInitLoop(const kir::Scope& loop_body) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h index 2ee1b4473277..ef51b64d7095 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h @@ -121,6 +121,12 @@ class TORCH_CUDA_CU_API WarpMmaSwizzler { TensorView* tv, const MmaOptions& options); + //! Swizzle implementations for Turing mma. + static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options); + static void scheduleTuringM16N8K16MmaWarpOutput( + TensorView* tv, + const MmaOptions& options); + //! Utility to lock the transformed dimensions from further transforms. 
static void setWarpMapped(TensorView* tv, int number_of_dims); }; diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index dca0f1e85f1e..1e3f4621f5fd 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -795,7 +795,7 @@ WelfordResult TensorView::rFactor( return WelfordResult(producer_avg, producer_var, producer_n); } -TensorView* TensorView::cache_before() { +TensorView* TensorView::cache_before(UnaryOpType cache_op) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); @@ -872,7 +872,7 @@ TensorView* TensorView::cache_before() { ir_utils::replaceValInExpr(definition(), this, producer); // Expr* producer_uses = - IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), cache_op, consumer, producer); // definition_ is no longer valid // setDefinition(nullptr); @@ -931,7 +931,7 @@ TensorView* TensorView::cache_fork() { return new_output; } -TensorView* TensorView::cache_after() { +TensorView* TensorView::cache_after(UnaryOpType cache_op) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); @@ -997,7 +997,7 @@ TensorView* TensorView::cache_after() { } // Expr* consumer_definition = - IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), cache_op, consumer, producer); return consumer; } diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 0742cea74b75..fde2294eaea8 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -344,6 +344,12 @@ static const char* unary_op_type2string(UnaryOpType t) { return "tanh"; case UnaryOpType::Trunc: return "trunc"; + case UnaryOpType::CP_ASYNC: + return "cpasync"; + case UnaryOpType::LD_MATRIX: + return "ldmatrix"; + case UnaryOpType::LD_MATRIXT: + return "ldmatrixt"; default: TORCH_INTERNAL_ASSERT(false, "No string found for unary op type."); } diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index d84a7b26564c..b7bb52f5fc0e 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -152,7 +152,16 @@ enum class UnaryOpType { Trunc, // Might be a bitwise operator or boolean operator. - Not + Not, + + // Memory ops, + // TODO & Note: Should consider moving these into + // a separate/new IR node, maybe LDST? Also should be moved + // are most of the usage cases of UnaryOp::Set, + // which are indeed representing memory read writes. + CP_ASYNC, + LD_MATRIX, + LD_MATRIXT }; // Primarily for Not, which could be Not a boolean, or a bitwise not. 
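Taken together, the pieces above give a scheduler a path from global memory to the Turing/Ampere mma operands. The following is a minimal, hypothetical sketch of that flow, not code from this patch: it assumes an Ampere target for cp.async, that tv0 is a half-precision fusion input feeding an mma op, and that mma_builder is an MmaBuilder already configured with the macro and operand layout; all tensor names are illustrative.

  // Stage the gmem -> smem copy with cp.async, then read the operand into
  // registers with ldmatrix in the layout the mma instruction expects.
  auto tv0_smem = tv0->cache_after(UnaryOpType::CP_ASYNC);
  tv0_smem->setMemoryType(MemoryType::Shared);
  auto tv0_reg = tv0_smem->cache_after(
      mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  // Swizzle the register tile into the warp-level mma operand layout.
  tv0_reg->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::A).build());

The ldmatrix stage is what makes the operand layout legal: Turing/Ampere mma is generated TN-only here, and validateTuringInput enforces that both operands arrive through an ldmatrix (possibly via a broadcast) until swizzles can be labeled directly on iterdomains.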
From 2f0ae5e4e6471b7deeee0dbabb3510f4976a94a3 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 22 Mar 2022 10:41:07 -0700 Subject: [PATCH 14/57] fix rebase --- torch/csrc/jit/codegen/cuda/codegen.cpp | 4 +- .../jit/codegen/cuda/lower_validation.cpp | 1 + torch/csrc/jit/codegen/cuda/runtime/memory.cu | 12 ++--- .../jit/codegen/cuda/runtime/tensorcore.cu | 44 +++++++------------ 4 files changed, 26 insertions(+), 35 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index f58b36de8daf..47ac6035d8bc 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -461,8 +461,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { std::string genVectorPointer(Val* val, DataType dtype, int vec_size) { std::stringstream ss; - ss << "reinterpret_cast*>(&" - << gen(val) << ")"; + ss << "reinterpret_cast*>(&" << gen(val) << ")"; return ss.str(); } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 7915b6a9f46d..683ec6840dc5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -964,6 +964,7 @@ void validateMma(Fusion* fusion) { // relaxed once swizzles can be labeled on iterdomains. validateTuringInput(mma->inA()->as()); validateTuringInput(mma->inB()->as()); + break; default: TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro"); break; diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index a4a9e6391e3a..4eadbf696e1a 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -20,7 +20,7 @@ DEVICE_INLINE unsigned toSmem(const void* ptr) { } // namespace util -DEVICE_INLINE void ldMatrix(Array<__half, 4>& out, void const* ptr) { +DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" @@ -28,7 +28,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4>& out, void const* ptr) { : "r"(addr)); } -DEVICE_INLINE void ldMatrixT(Array<__half, 4>& out, void const* ptr) { +DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" @@ -36,7 +36,7 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 4>& out, void const* ptr) { : "r"(addr)); } -DEVICE_INLINE void ldMatrix(Array<__half, 8>& out, void const* ptr) { +DEVICE_INLINE void ldMatrix(Array<__half, 8, 8>& out, void const* ptr) { uint4& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" @@ -44,7 +44,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 8>& out, void const* ptr) { : "r"(addr)); } -DEVICE_INLINE void ldMatrixT(Array<__half, 8>& out, void const* ptr) { +DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { uint4& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); asm volatile( @@ -79,7 +79,9 @@ DEVICE_INLINE unsigned toSmem(void* ptr) { } // namespace util template -DEVICE_INLINE void cpAsync(Array* smem_ptr, void const* gmem_ptr) { +DEVICE_INLINE void cpAsync( + Array* smem_ptr, + void const* gmem_ptr) { unsigned smem_addr = util::toSmem(&(smem_ptr->val[0])); constexpr 
int byte_size = sizeof(dtype) * len; diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu index 7d240f24a1dc..11a6a1202d1d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -219,9 +219,9 @@ namespace Turing { namespace util { // MMA instruction wrappers (sm_75+): DEVICE_INLINE void m16n8k16TN( - Array* C, - Array<__half, 8>* A, - Array<__half, 4>* B) { + Array* C, + Array<__half, 8, 8>* A, + Array<__half, 4, 4>* B) { unsigned const* _A = reinterpret_cast(A); unsigned const* _B = reinterpret_cast(B); unsigned* _C = reinterpret_cast(C); @@ -243,18 +243,26 @@ DEVICE_INLINE void m16n8k16TN( } // namespace util +template +DEVICE_INLINE void initM16N8K16TN(Array* accumulator) { + float* _C = reinterpret_cast(accumulator); + _C[0] = 0; + _C[1] = 0; + _C[acc_stride] = 0; + _C[acc_stride + 1] = 0; +} + template DEVICE_INLINE void M16N8K16TN( - Array* C, - Array<__half, 8>* A, - Array<__half, 4>* B) { + Array* C, + Array<__half, 8, 8>* A, + Array<__half, 4, 4>* B) { // TODO: in a follow up, // lift this fused swizzle onto iterdomain float* _C = reinterpret_cast(C); - float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]}; - util::m16n8k16TN(reinterpret_cast*>(&C_data[0]), A, B); + util::m16n8k16TN(reinterpret_cast*>(&C_data[0]), A, B); _C[0] = C_data[0]; _C[1] = C_data[1]; @@ -262,26 +270,6 @@ DEVICE_INLINE void M16N8K16TN( _C[acc_stride + 1] = C_data[3]; } -// MMA instruction wrappers (sm_75+): -DEVICE_INLINE void M16N8K8TN( - Array* C, - Array<__half, 4>* A, - Array<__half, 2>* B) { - unsigned const* _A = reinterpret_cast(A); - unsigned const* _B = reinterpret_cast(B); - unsigned* _C = reinterpret_cast(C); - - asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n" - : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3]) - : "r"(_A[0]), - "r"(_A[1]), - "r"(_B[0]), - "r"(_C[0]), - "r"(_C[1]), - "r"(_C[2]), - "r"(_C[3])); -} - } // namespace Turing #endif // Arch 75 From b96462a6710da8b557f3f67118d8de9b65738d15 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 23 Mar 2022 15:27:54 -0700 Subject: [PATCH 15/57] more rebase fix --- .../jit/codegen/cuda/lower_insert_syncs.cpp | 2 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 18 +++++++++++++++++- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 74bb4c9d9e0c..be61f80b25d6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -525,7 +525,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { sync_bitmap, maybe_alloc->buffer()); } else { bool need_gmem_sync = gmem_sync_after_.count(expr); - auto sync_expr = IrBuilder::create( + sync_expr = IrBuilder::create( false, // is not war sync need_gmem_sync); } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index ff9fc411db6b..b5559f770845 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -120,6 +120,15 @@ void UnrollPass::handle(Expr* expr) { return false; }; + auto is_cpasync_op = [](Expr* expr) { + if (auto uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + return true; + } + } + return false; + }; + bool is_ld_matrix = 
false; bool is_ld_matrix_producer = false; if (is_ld_matrix_op(expr)) { @@ -149,12 +158,19 @@ void UnrollPass::handle(Expr* expr) { val->as()->getMemoryType() == MemoryType::Global; })) { for (auto out_tv : ir_utils::filterByType(expr->outputs())) { + if (is_cpasync_op(out_tv->definition())) { + // cp async op has initialization outside of the predicate protected + // boundary so no need to ensure predicate removal. + break; + } + for (auto use : expr->fusion()->unordered_uses(out_tv)) { if (is_ld_matrix_op(use)) { TORCH_INTERNAL_ASSERT( GpuLower::current()->predicateElimination().canOmitPredicate( expr), - "Vectorize: producers of ldmatrix cannot be predicated except for global loads"); + "Vectorize: producers of ldmatrix cannot be predicated except for global loads:", + expr); is_ld_matrix_producer = true; break; } diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 4eadbf696e1a..4062d324f563 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -82,7 +82,7 @@ template DEVICE_INLINE void cpAsync( Array* smem_ptr, void const* gmem_ptr) { - unsigned smem_addr = util::toSmem(&(smem_ptr->val[0])); + unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); constexpr int byte_size = sizeof(dtype) * len; static_assert( From 5347fdaae0257b7d3b3830a73deb2b459b05f9d2 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 28 Mar 2022 09:54:13 -0700 Subject: [PATCH 16/57] test comment --- test/cpp/jit/test_gpu_tensorcore.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index 50b4269af882..7cbb5809170d 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -891,6 +891,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } +// MMA unit test on turing TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); @@ -959,6 +960,7 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } +// MMA unit test on turing TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); @@ -1031,6 +1033,7 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } +// MMA unit test on turing TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); @@ -1106,7 +1109,8 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionTuringGemmTN_CUDA) { +// Matmul test on turing +TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); Fusion fusion; @@ -1254,7 +1258,8 @@ TEST_F(NVFuserTest, FusionTuringGemmTN_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -TEST_F(NVFuserTest, FusionTuringGemmTT_CUDA) { +// Matmul test on turing +TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); Fusion fusion; @@ -1403,7 +1408,8 @@ TEST_F(NVFuserTest, FusionTuringGemmTT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -TEST_F(NVFuserTest, FusionTuringGemmNT_CUDA) { +// Matmul test on turing +TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); Fusion fusion; @@ -1553,7 +1559,8 @@ 
TEST_F(NVFuserTest, FusionTuringGemmNT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -TEST_F(NVFuserTest, FusionGemmGemmTuring_CUDA) { +// Matmul-Matmul fusion test on turing +TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); Fusion fusion; @@ -1827,7 +1834,9 @@ TEST_F(NVFuserTest, FusionGemmGemmTuring_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } -TEST_F(NVFuserTest, FusionGemmSoftmaxGemmTuring_CUDA) { +// Simplified Matmul-Softmax-Matmul test on turing +// (To be extended in follow ups) +TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); Fusion fusion; @@ -2204,6 +2213,7 @@ TEST_F(NVFuserTest, FusionGemmSoftmaxGemmTuring_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); } +// Matmul test on ampere, using ampere memory ops TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); From 4834995fb17cd5b30f4fca032d885105c38ec9be Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 28 Mar 2022 10:04:52 -0700 Subject: [PATCH 17/57] submodule --- third_party/XNNPACK | 2 +- third_party/fbgemm | 2 +- third_party/flatbuffers | 2 +- third_party/ideep | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/XNNPACK b/third_party/XNNPACK index 79cd5f9e18ad..ae108ef49aa5 160000 --- a/third_party/XNNPACK +++ b/third_party/XNNPACK @@ -1 +1 @@ -Subproject commit 79cd5f9e18ad0925ac9a050b00ea5a36230072db +Subproject commit ae108ef49aa5623b896fc93d4298c49d1750d9ba diff --git a/third_party/fbgemm b/third_party/fbgemm index 7588d9d80482..1ddff63cd3a9 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 7588d9d804826b428fc0e4fd418e9cc3f7a72e52 +Subproject commit 1ddff63cd3a99bdd8f52e8147dbfe723522d2f48 diff --git a/third_party/flatbuffers b/third_party/flatbuffers index 697147a2e686..d0cede9c90c5 160000 --- a/third_party/flatbuffers +++ b/third_party/flatbuffers @@ -1 +1 @@ -Subproject commit 697147a2e686486424b9d15fc3e1612586a60f97 +Subproject commit d0cede9c90c5257537c293517a21376408b549fa diff --git a/third_party/ideep b/third_party/ideep index 82aac435b5ec..4a56ab2c3f61 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 82aac435b5ecfec0855d0d72b84aee3ed0e72813 +Subproject commit 4a56ab2c3f61c44e0f8ea241beeb732b7d70dc5b From 9e18e04ead98191f413967250912570bbf48d203 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 28 Mar 2022 10:54:59 -0700 Subject: [PATCH 18/57] cleanup and comments --- torch/csrc/jit/codegen/cuda/codegen.cpp | 11 ++++- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 7 +++ torch/csrc/jit/codegen/cuda/kernel_ir.h | 1 + .../jit/codegen/cuda/lower_double_buffer.cpp | 6 +++ .../jit/codegen/cuda/lower_double_buffer.h | 13 ++++++ .../jit/codegen/cuda/lower_insert_syncs.cpp | 14 ++++++ .../csrc/jit/codegen/cuda/lower_predicate.cpp | 44 ++++++++++++++----- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 6 ++- .../jit/codegen/cuda/lower_validation.cpp | 4 +- 9 files changed, 90 insertions(+), 16 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 47ac6035d8bc..60e58feb40b8 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -458,6 +458,12 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT(false, "Unreachable"); } + //! Utility for generating vectorized pointer access in ldsm and + //! cpasync. + //! 
TODO: this access pattern as is could be merged with exisiting + //! vectorization handling logic but this path will be updated in + //! follow ups to optimize the generated assembly so keeping them + //! separate path for now. std::string genVectorPointer(Val* val, DataType dtype, int vec_size) { std::stringstream ss; @@ -946,7 +952,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (isVolta(options.macro)) { ss << toString(options.operand_layout); } else if (isTuring(options.macro) || isAmpere(options.macro)) { - // mma's in turing and ampere TN only. + // mma's in turing and ampere TN only, transpose is handled either + // via ldmatrix for fp16 or explicitly for other types. ss << "TN"; } // TODO: additional parameter could be removed by swizzling iterdomain @@ -1857,6 +1864,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::BlockSync* sync) final { + // Issue a barrier to sync with async loads, + // see [Note: dma sync with gmem loads] if (sync->syncGmem()) { indent() << "Ampere::cpAsyncBarrier();\n"; } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index e9ab0cdfe4a0..b252ee146c04 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1095,6 +1095,13 @@ void IterDomain::parallelize(ParallelType t) { } if (isMmaSwizzled()) { + // Mma swizzled axes represent data representation within a warp + // so only allow updates that keep the parallelization within + // a warp. + // Note && TODO: this check is actually used to allow indexing path + // to make copies of the iterdomains. We might eventually just want + // to lock these parallel types and not allowing any changes once + // they are swizzled. TORCH_CHECK( t == ParallelType::Vectorize || t == ParallelType::TIDx || t == ParallelType::Serial, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index e52cf18e3de0..6b069ae0c18a 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -261,6 +261,7 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr { // TODO: war_sync_ is only used for testing/validation purposes. bool war_sync_ = false; + //! [Note: Dma sync with gmem loads] //! Indicates if this block sync also synchronizes with asynchronous //! copy ops from global mem to shared mem. //! Currently making this a parameter in BlockSync so that we always diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 5c739f625c44..61d4ca379b75 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -127,6 +127,10 @@ class DoubleBufferFusionInspector : private IterVisitor { TORCH_INTERNAL_ASSERT( tv->definition(), "Fusion input shouldn't be double buffered.", tv); + + // Record that there is a double buffered tensor that is a consumer + // of async gmem load. The double buffer sync will have to sync + // with the async load as well. if (auto uop = dynamic_cast(tv->definition())) { if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { db_info_.setAsyncGmemDoubleBuffer(); @@ -418,6 +422,8 @@ class DoubleBufferInserter : private kir::ExprMutator { // RAW sync is not inserted for double buffered tensors. The only // exception is the prologue load. if (write_to_smem) { + // Here the initial sync before entering double buffer loop is + // inserted. 
It will also need to sync with async gmem loads. auto sync = IrBuilder::create( false, GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer()); diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 731d20eca2b2..d94184141e76 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -128,10 +128,20 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { Val* getOriginalAllocSize(const TensorView* tv); + //! Sets the async gmem load flag indicating that + //! a cpasync load has been double buffered. void setAsyncGmemDoubleBuffer() { has_async_gmem_double_buffer_ = true; } + //! Check if any double buffered tensor is from + //! async gmem load. The current analysis is + //! simplistic to begin with, meaning we always + //! sync with the gmem load if we have any + //! sharedmem double buffered tensor is async + //! loaded. + //! TODO: will need to extend to support more complex + //! pipeline patterns in follow ups. bool hasAsyncGmemDoubleBuffer() { return has_async_gmem_double_buffer_; } @@ -142,6 +152,9 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { private: //! Keeps track of information for lowering double buffered tensors std::unordered_map map_; + + //! Tracks if any double buffered tensor is consumer of async + //! gmem load. bool has_async_gmem_double_buffer_ = false; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index be61f80b25d6..f3d62f41255f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -134,6 +134,17 @@ class WarSyncInserter : private kir::ExprMutator { for (const auto& entry : lower_alloc_info_map) { alloc_map_.insert(entry.first); } + + // Currently the RAW sync is removed in the double buffer loop if + // a shared mem tensor is double buffered. In this case the sync + // inserted by war pass is actually both war and raw together. + // This part of the logic extension is for the case when a sharedmem + // tensor is double buffered and async loaded, the double buffer sync + // in this case need to also sync with the async loads. + // Note & TODO: + // might at some point want to separate out as another category like + // "double buffer sync" if we need to keep extending the + // war sync path but actually to prevent raw hazard. need_to_sync_gmem_ = GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer(); kir::ExprMutator::traverseAndInsert(exprs); @@ -336,6 +347,7 @@ class WarSyncInserter : private kir::ExprMutator { // for WAR insertion. std::unordered_map> smem_allocations_; + //! Tracks if the inserted syncs need to also sync with the gmem async loads. bool need_to_sync_gmem_ = false; }; @@ -670,6 +682,8 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { auto last_smem_writes = isModifiedSharedMemory(smem, expr->inputs()); + // RAW sync insertion in the case where the load isn't double buffered. + // This sync is needed when any of the producer writes is async. 
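// For context, a self-contained sketch of the ordering this sync enforces when
// the shared memory producer is an asynchronous copy. This is illustrative
// only, not generated code: the kernel, buffer names, and the assumption of a
// 128-thread block are made up here, and it requires sm_80.
__global__ void cpAsyncRawOrderingSketch(
    const float4* gmem_in,
    float4* gmem_out) {
  __shared__ float4 smem_tile[128];
  // Convert the generic smem pointer to a 32-bit shared address for PTX.
  unsigned smem_addr;
  asm("{ .reg .u64 p; cvta.to.shared.u64 p, %1; cvt.u32.u64 %0, p; }\n"
      : "=r"(smem_addr)
      : "l"(&smem_tile[threadIdx.x]));
  // Asynchronous 16B gmem -> smem copy; not yet visible to other threads.
  asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
               "l"(&gmem_in[threadIdx.x]));
  // First wait for the async copies issued by this thread to complete ...
  asm volatile("cp.async.wait_all;");
  // ... then the regular RAW block sync makes every thread's write visible.
  __syncthreads();
  gmem_out[threadIdx.x] = smem_tile[(threadIdx.x + 127) % 128];
}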
bool need_gmem_sync = std::any_of( last_smem_writes.begin(), last_smem_writes.end(), [](Expr* expr) { if (auto uop = dynamic_cast(expr)) { diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 6a1224573146..2d3373234d2f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -50,7 +50,21 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { setWritePredicate(expr, conditional); } - if (isPredicatedGmemToSharedInit(expr)) { + // This is special handling for async copy from gmem to smem. + // The actual copy part is async to the gpu threads so we + // cannot have both the copy and the initialization as we + // initialize for other ops. This current approach inverts + // the predicate to make sure that the init and copy paths are + // never on at the same time, but as a side effect this predicate + // can only guard the gmem access and predicate removal pass + // will assert that the usage of cpasync is limited to cases + // where the smem consumer predicate is not needed. + // + // TODO: in a follow up we need to extend the predicate + // infrastructure to generate predicate for both gmem + // and smem, and the predicate removal will need to + // be extended as well for the perf critical regions. + if (isPredicatedInitForCpAsync(expr)) { invertPredicateForGmemToSharedMemInitialize(expr); } @@ -63,7 +77,9 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { expr->predicate()->setValue(invert->as()); } - bool isPredicatedGmemToSharedInit(Expr* expr) { + // Detect if this expr is an initialization for vectorized + // cp asyc. + bool isPredicatedInitForCpAsync(Expr* expr) { if (!expr->predicate() || expr->predicate()->predicate_type() != PredicateType::Vectorize) { return false; @@ -84,12 +100,11 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { TORCH_INTERNAL_ASSERT( fuser_expr != nullptr, "kir tv expr without corresponding fuser expr defining the output"); - if (auto in_fuser_tv = - dynamic_cast(fuser_expr->input(0))) { - return in_fuser_tv->getMemoryType() == MemoryType::Global && - out_tv->view()->getMemoryType() == MemoryType::Shared && - kir_uop->in()->isConstScalar(); + auto fuser_uop = dynamic_cast(fuser_expr); + if (!fuser_uop) { + return false; } + return fuser_uop->getUnaryOpType() == UnaryOpType::CP_ASYNC; } } } @@ -215,6 +230,10 @@ class PredicateAnalyzer : public OptOutDispatch { return true; } + // This is initial step to gradually remove predicates around + // sharedmem access in suitable situations. + // Using an additional variable to track the predicate-on reasons + // when the predicate around shared mem cannot be removed. bool needs_sharedmem_addr_pred = false; if (producer->getMemoryType() == MemoryType::Shared || consumer->getMemoryType() == MemoryType::Shared) { @@ -229,18 +248,19 @@ class PredicateAnalyzer : public OptOutDispatch { } } - // TODO: (Address in a follow up) - // predicate removal with init breaks unroll and unswitch, eg. as in - // issue 1133 disabling this usage for now + // TODO: (Enable in a follow up) + // smem predicate removal with init would break unroll and unswitch, + // eg. as in issue 1133, so disabling this removal pattern for now. 
if (id->getParallelType() == ParallelType::Unroll || id->getParallelType() == ParallelType::Unswitch) { needs_sharedmem_addr_pred = true; } - // TODO: (Address in a follow up) + // TODO: (Enable in a follow up) // This cannot yet be removed since smem initialization needs to be // handled specially, e.g. as in smem_reduce test. Will be able to - // lift this one once the generic pred removal pass is ready. + // lift this one once the generic pred removal pass with fusion + // traversal is ready. auto consumer_def = consumer->definition(); if (consumer_def->isA() || consumer_def->isA()) { diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index b5559f770845..e59274f72f1a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -110,6 +110,7 @@ void UnrollPass::handle(Expr* expr) { // Vectorized expressions should never use inline predicates kir::Predicate* pred = nullptr; + // Check if the given op is an ldmatrix auto is_ld_matrix_op = [](Expr* expr) { if (auto uop = dynamic_cast(expr)) { if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || @@ -120,6 +121,7 @@ void UnrollPass::handle(Expr* expr) { return false; }; + // Check if the given op is an async copy auto is_cpasync_op = [](Expr* expr) { if (auto uop = dynamic_cast(expr)) { if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { @@ -132,10 +134,10 @@ void UnrollPass::handle(Expr* expr) { bool is_ld_matrix = false; bool is_ld_matrix_producer = false; if (is_ld_matrix_op(expr)) { - // Note & TODO: currently not support predicated ldmatrix since it'd + // Note & TODO: currently cannot support predicated ldmatrix since it'd // need to make sure all threads in a warp evaluate the predicate // to the same value. - // Currently asserting that the predicate for ldmatrix and its producer + // For now asserting that the predicate for ldmatrix and its producer // can be omitted. Would need to build out support for warp level // predicates if we want vectorization predicates to be active on // these ops. diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 95ff08068772..bc1289dd4d98 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -925,7 +925,9 @@ void validateLdMatrixOutput(TensorView* tv) { if (out_uses.empty()) { return; } - // Could be relaxed + // TODO: restricting to single use pipelines for now which + // is true to matmul mainloop. This Could be relaxed to + // support more complex mma usage. TORCH_INTERNAL_ASSERT(out_uses.size() == 1); auto out_use = *(out_uses.begin()); From 1fb32ed536e0c6779a797c490593fd2321172770 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 28 Mar 2022 13:54:12 -0700 Subject: [PATCH 19/57] minor fix --- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 2d3373234d2f..cddc957c50a1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -78,10 +78,9 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { } // Detect if this expr is an initialization for vectorized - // cp asyc. + // cp asyc with predicates. 
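// For reference, the runtime pattern this check is looking for, written out as
// a plain CUDA sketch. The names are illustrative, and the in-bounds branch
// uses a synchronous assignment as a stand-in for the vectorized cp.async.
__device__ void cpAsyncInitThenLoadSketch(
    float* smem_tile, // shared memory destination
    const float* gmem_src, // global memory source
    int i, // element handled by this thread
    bool in_bounds) { // the vectorize predicate
  if (!in_bounds) {
    // Inverted predicate: initialize only the elements the copy skips, so the
    // init path and the async copy path are never both active.
    smem_tile[i] = 0.f;
  }
  if (in_bounds) {
    // In the generated kernel this is the predicated asynchronous copy.
    smem_tile[i] = gmem_src[i];
  }
}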
bool isPredicatedInitForCpAsync(Expr* expr) { - if (!expr->predicate() || - expr->predicate()->predicate_type() != PredicateType::Vectorize) { + if (expr->predicate() == nullptr) { return false; } if (auto ite = dynamic_cast(expr)) { @@ -95,16 +94,22 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { // Checking if this expr writes scalar to shared mem tensor, // while fuser expr read from gmem to shared mem. if (auto kir_uop = dynamic_cast(inner_expr)) { + // Initialization ops assumed to have scalar input. + if (!kir_uop->in()->isScalar()) { + continue; + } if (auto out_tv = dynamic_cast(kir_uop->out())) { auto fuser_expr = out_tv->view()->definition(); - TORCH_INTERNAL_ASSERT( - fuser_expr != nullptr, - "kir tv expr without corresponding fuser expr defining the output"); - auto fuser_uop = dynamic_cast(fuser_expr); - if (!fuser_uop) { - return false; + if (fuser_expr == nullptr) { + // This shouldn't hit but added this skip here in case + // any future kir ops have this behavior. + continue; + } + if (auto fuser_uop = dynamic_cast(fuser_expr)) { + if (fuser_uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { + return true; + } } - return fuser_uop->getUnaryOpType() == UnaryOpType::CP_ASYNC; } } } From 43d0e72dfd38d6747324b974cf04b2caf0c56cc1 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 28 Mar 2022 15:05:02 -0700 Subject: [PATCH 20/57] test cleanup --- test/cpp/jit/test_gpu_tensorcore.cpp | 41 +++++++++++++--------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index 7cbb5809170d..084d555b94d1 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -947,12 +947,12 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({8, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -1020,12 +1020,12 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({16, 8}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion,{t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1096,12 +1096,12 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({16, 8}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1249,8 +1249,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = 
t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -1399,8 +1398,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1550,8 +1548,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1821,15 +1818,15 @@ TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { auto t1 = at::randn({K2, K1}, options); auto t2 = at::randn({N, K2}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto cg_outputs = fe.runFusion({t0, t1, t2}); - auto tref = t0.to(at::kFloat) .matmul(t1.t().to(at::kFloat)) .matmul(t2.t().to(at::kFloat)); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0,t1,t2}); + + auto cg_outputs = fe.runFusion({t0, t1, t2}); + // relaxed check for now, err accumulation is significant. TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } @@ -2202,7 +2199,7 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { auto t2 = at::randn({N2, K2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0,t1,t2}); auto cg_outputs = fe.runFusion({t0, t1, t2}); @@ -2343,7 +2340,7 @@ TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0,t1}); auto cg_outputs = fe.runFusion({t0, t1}); From c1f437471627ae50e4d3b87514ca20690815814f Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 29 Mar 2022 11:53:58 -0700 Subject: [PATCH 21/57] format --- test/cpp/jit/test_gpu_tensorcore.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/cpp/jit/test_gpu_tensorcore.cpp b/test/cpp/jit/test_gpu_tensorcore.cpp index 084d555b94d1..152e5b6bc28c 100644 --- a/test/cpp/jit/test_gpu_tensorcore.cpp +++ b/test/cpp/jit/test_gpu_tensorcore.cpp @@ -950,7 +950,7 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({8, 16}, options); - + FusionExecutor fe; fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); @@ -1023,9 +1023,9 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({16, 8}, options); - + FusionExecutor fe; - fe.compileFusion(&fusion,{t0,t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1099,9 +1099,9 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto t0 = at::randn({16, 16}, options); auto t1 = at::randn({16, 8}, options); - + FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1249,7 +1249,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1}); + 
fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -1398,7 +1398,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1548,7 +1548,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -1823,7 +1823,7 @@ TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { .matmul(t2.t().to(at::kFloat)); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1,t2}); + fe.compileFusion(&fusion, {t0, t1, t2}); auto cg_outputs = fe.runFusion({t0, t1, t2}); @@ -1831,7 +1831,7 @@ TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } -// Simplified Matmul-Softmax-Matmul test on turing +// Simplified Matmul-Softmax-Matmul test on turing // (To be extended in follow ups) TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); @@ -2199,7 +2199,7 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { auto t2 = at::randn({N2, K2}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1,t2}); + fe.compileFusion(&fusion, {t0, t1, t2}); auto cg_outputs = fe.runFusion({t0, t1, t2}); @@ -2340,7 +2340,7 @@ TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0,t1}); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); From 7c91a0a6d77c9533c4806b5ad60ccfef2d02e53f Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 29 Mar 2022 11:57:10 -0700 Subject: [PATCH 22/57] newline --- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 4062d324f563..26157e89a2b1 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -104,4 +104,4 @@ DEVICE_INLINE void cpAsyncBarrier() { #endif // Arch 80 -#undef DEVICE_INLINE \ No newline at end of file +#undef DEVICE_INLINE From f4c6f12b326e5e155d5e3cf7482d1fa1d34a30dc Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 25 Apr 2022 10:38:19 -0700 Subject: [PATCH 23/57] move all implementation to Ampere space --- .../jit/codegen/cuda/lower_validation.cpp | 4 +- torch/csrc/jit/codegen/cuda/mma_type.cpp | 16 ++-- torch/csrc/jit/codegen/cuda/mma_type.h | 2 +- .../jit/codegen/cuda/runtime/tensorcore.cu | 4 +- .../jit/codegen/cuda/scheduler/mma_utils.cpp | 8 +- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 76 +++++++++---------- 6 files changed, 55 insertions(+), 55 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index e5aa1ecb626e..9c619de12da4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1026,8 +1026,8 @@ void validateMma(Fusion* fusion) { case MmaOptions::MacroType::Volta_16_16_4: validateMinimumArch(7, 0); break; - case MmaOptions::MacroType::Turing_16_8_16: - 
validateMinimumArch(7, 5); + case MmaOptions::MacroType::Ampere_16_8_16: + validateMinimumArch(8, 0); // Check that operands come from ldmatrix, can be // relaxed once swizzles can be labeled on iterdomains. diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 9fb18cfa6f88..b0ba8fed846b 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -16,7 +16,7 @@ MmaBuilder::MmaBuilder( case MmaOptions::MacroType::Volta_16_16_4: option_.accumulator_stride = outer_stride * 4; break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: option_.accumulator_stride = outer_stride * 2; break; default: @@ -46,7 +46,7 @@ namespace { UnaryOpType getLdMatrixType(MmaOptions options) { bool transpose = false; switch (options.macro) { - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: // Turing mma assumes TN as default transpose = (options.operand == MmaOptions::Operand::A && !isOperandTransposed(options)) || @@ -72,11 +72,11 @@ bool isVolta(MmaOptions::MacroType macro) { } bool isTuring(MmaOptions::MacroType macro) { - return macro == MmaOptions::MacroType::Turing_16_8_16; + return false; } bool isAmpere(MmaOptions::MacroType macro) { - return false; + return macro == MmaOptions::MacroType::Ampere_16_8_16; } int getOutputRegisterSize(MmaOptions::MacroType macro) { @@ -84,7 +84,7 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 8; break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: return 4; break; default: @@ -99,7 +99,7 @@ int getInputARegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: return 8; break; default: @@ -114,7 +114,7 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: return 4; default: TORCH_INTERNAL_ASSERT(false, "unknown macro"); @@ -164,7 +164,7 @@ std::string toString(MmaOptions::MacroType mt) { case MmaOptions::MacroType::Volta_16_16_4: ss << "M16N16K4"; break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: ss << "M16N8K16"; break; default: diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 40d8aab410a6..f762d683b3a0 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -57,7 +57,7 @@ struct MmaOptions { enum class MacroType { NoMMA = 0, Volta_16_16_4, - Turing_16_8_16, + Ampere_16_8_16, Ampere_16_8_8 // place holder for tf32 }; diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu index 11a6a1202d1d..01d27b343947 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -214,7 +214,7 @@ DEVICE_INLINE void initM16N16K4NT(Array* accumulator) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -namespace Turing { +namespace Ampere { namespace util { // MMA instruction wrappers (sm_75+): @@ -270,7 +270,7 @@ DEVICE_INLINE void M16N8K16TN( _C[acc_stride + 1] = C_data[3]; } -} // namespace Turing +} // namespace Ampere #endif // Arch 75 diff --git 
a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 6dfe7e0b0944..fe50d1e3b41c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -88,7 +88,7 @@ IterDomain* getMmaOperandRootDimension( if (isVolta(options.macro)) { return getMmaOperandRootDimension3d( tv, options.operand_layout, mma_dimension); - } else if (isTuring(options.macro)) { + } else if (isTuring(options.macro) || isAmpere(options.macro)) { // Volta mma swizzle requires the broadcast dimension to // participate, which is not true in Turing+. So the two // cases w/ or w/o the broadcast are supported here for @@ -220,7 +220,7 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput( setWarpMapped(tv, 5); } break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: scheduleTuringM16N8K16MmaWarpOutput(tv, options); if (tv->definition()->isA()) { setWarpMapped(tv, 4); @@ -241,7 +241,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) { case MmaOptions::MacroType::Volta_16_16_4: scheduleVoltaOperandRead(tv, options); break; - case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: scheduleTuringOperandRead(tv, options); break; default: @@ -399,7 +399,7 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { : isOperandTransposed(options); // Check mma option is supported TORCH_CHECK( - options.macro == MmaOptions::MacroType::Turing_16_8_16, + options.macro == MmaOptions::MacroType::Ampere_16_8_16, "scheduleLdMatrix: unknown macro for ldmatrix"); if (options.operand == MmaOptions::Operand::A) { diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index ce8f089d6f25..0f6fc63fa4f6 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -890,9 +890,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -// MMA unit test on turing -TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// MMA unit test on Ampere +TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -920,7 +920,7 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); tv2->configureMma(mma_builder.build()); @@ -959,9 +959,9 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// MMA unit test on turing -TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// MMA unit test on Ampere +TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -987,7 +987,7 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); tv2->configureMma(mma_builder.build()); @@ -1032,9 +1032,9 @@ TEST_F(NVFuserTest, 
FusionTuringMMATT_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// MMA unit test on turing -TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// MMA unit test on Ampere +TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1059,7 +1059,7 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); tv2->configureMma(mma_builder.build()); @@ -1108,9 +1108,9 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); } -// Matmul test on turing -TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// Matmul test on Ampere +TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1139,7 +1139,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); tv2->configureMma(mma_builder.build()); @@ -1256,9 +1256,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -// Matmul test on turing -TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// Matmul test on Ampere +TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1287,7 +1287,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); tv2->configureMma(mma_builder.build()); @@ -1405,9 +1405,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -// Matmul test on turing -TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// Matmul test on Ampere +TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1434,7 +1434,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); tv2->configureMma(mma_builder.build()); @@ -1555,9 +1555,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } -// Matmul-Matmul fusion test on turing -TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +// Matmul-Matmul fusion test on Ampere +TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1605,16 +1605,16 @@ TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { gemm_tile1.warp_tile = GemmTile(64, 32, 32); gemm_tile2.warp_tile = 
GemmTile(64, 16, 64); - // Using turing mma macro + // Using Ampere mma macro gemm_tile2.instruction_tile = GemmTile(16, 8, 16); gemm_tile2.instruction_tile = GemmTile(16, 8, 16); auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile1) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1) .layout(MmaOptions::MmaInputLayout::TN); auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile2) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) .layout(MmaOptions::MmaInputLayout::TN); tv3->configureMma(mma_builder1.build()); @@ -1830,10 +1830,10 @@ TEST_F(NVFuserTest, FusionMatmulMatmulTuring_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); } -// Simplified Matmul-Softmax-Matmul test on turing +// Simplified Matmul-Softmax-Matmul test on Ampere // (To be extended in follow ups) -TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); +TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; FusionGuard fg(&fusion); @@ -1905,15 +1905,15 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { // Distribute to 2x2 warps gemm_tile.warp_tile = GemmTile(16, 64, 32); - // Using turing mma macro + // Using Ampere mma macro gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder1 = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); auto mma_builder2 = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); tv3->configureMma(mma_builder1.build()); @@ -2210,7 +2210,7 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulTuring_CUDA) { } // Matmul test on ampere, using ampere memory ops -TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { +TEST_F(NVFuserTest, FusionAmpereGemmTNcpAsync_CUDA) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); Fusion fusion; @@ -2241,7 +2241,7 @@ TEST_F(NVFuserTest, FusionAmpereGemmTN_CUDA) { gemm_tile.instruction_tile = GemmTile(16, 8, 16); auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); tv2->configureMma(mma_builder.build()); From d719bd84e77dda6c88a45766056e3a1e386435c5 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 25 Apr 2022 13:43:29 -0700 Subject: [PATCH 24/57] comment and naming --- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 25 +++++++---- .../jit/codegen/cuda/lower_validation.cpp | 14 +++++-- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 41 ++++++++++++++----- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index a21743e0dd2d..8daea90b4683 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -48,15 +48,22 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { setWritePredicate(expr, conditional); } - // This is special handling for async copy from gmem to smem. - // The actual copy part is async to the gpu threads so we - // cannot have both the copy and the initialization as we - // initialize for other ops. 
This current approach inverts - // the predicate to make sure that the init and copy paths are - // never on at the same time, but as a side effect this predicate - // can only guard the gmem access and predicate removal pass - // will assert that the usage of cpasync is limited to cases - // where the smem consumer predicate is not needed. + // Today for vectorized support the pattern is: + // Initialize buffer -> predicated load + // For memcpy async: + // If we initialized and then loaded (without sync) it would be undefined + // behavior. + // Initialize only the "virtual out of boundary" accesses. + // Memory allocated, but outside the virtual tensor space. + // Virtual tensor space today is effectively what would be allocated in + // global memory. Then only copy the "within bound" accesses. + // This is a WAR today based on how our system is set up. + // We would want to have a separate concept of SMEM space from Virtual or + // GMEM space, so that we know we're only working with the allocated + // SMEM. + // If we hit outside the allocated SMEM bad things happen. + // Today asserting in predicate removal making sure that the virtual and + // SMEM boundaries line up based on the IterDomains. // // TODO: in a follow up we need to extend the predicate // infrastructure to generate predicate for both gmem diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 9c619de12da4..27d5d4d1f4c4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -942,7 +942,15 @@ void validateMmaTensors(MmaOp* mma) { //! mma's. //! This restriction will eventually not //! be necessary once the scatter swizzle is ready. -void validateTuringInput(TensorView* tv) { +void validateTuringMmaInput(TensorView* tv) { + // Pattern matching here to make sure LDMatrix is the right format. + // Format is done through swizzling in the scheduling and + // we check that swizzling to make sure it's correctly setup for LDMatrix. + // We could in theory support patterns LDMatrix doesn't support, + // but that would also mean the MMA isn't supported and + // so we would have to lower to something completely different. + + // MemCpy async is a more generic utility that we can use. // Currently only allowed input paths are: // ldmatrix -> mma or // ldmatrix -> broadcast -> mma @@ -1031,8 +1039,8 @@ void validateMma(Fusion* fusion) { // Check that operands come from ldmatrix, can be // relaxed once swizzles can be labeled on iterdomains. - validateTuringInput(mma->inA()->as()); - validateTuringInput(mma->inB()->as()); + validateTuringMmaInput(mma->inA()->as()); + validateTuringMmaInput(mma->inB()->as()); break; default: TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro"); diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 26157e89a2b1..8a3e93e54196 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -7,19 +7,32 @@ namespace Turing { namespace util { -// Special utility for ldmatrix and cp_async -DEVICE_INLINE unsigned toSmem(const void* ptr) { - unsigned smem_ptr; - +// Utility for converting generic pointer to SMEM pointer in PTX. +// We should review vectorized load/stores with shared memory. +// SMEM memory movement PTX is only Global -> SMEM, SMEM -> Local, Local -> +// SMEM, and this is needed for these PTX instructions to provide the SMEM +// pointer. 
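// For comparison, the same conversion is available as a compiler builtin on
// CUDA 11+ toolkits; a sketch of that variant (not part of this patch) is shown
// here next to the inline-PTX version below:
DEVICE_INLINE unsigned toSmemViaIntrinsic(const void* raw_ptr) {
  // __cvta_generic_to_shared returns the shared-window offset as a size_t;
  // the PTX consumers (ldmatrix, cp.async) take it as a 32-bit register.
  return static_cast<unsigned>(__cvta_generic_to_shared(raw_ptr));
}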
+DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { + unsigned smem_ptr_uint; asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" - : "=r"(smem_ptr) - : "l"(ptr)); + : "=r"(smem_ptr_uint) + : "l"(raw_ptr)); - return smem_ptr; + return smem_ptr_uint; } } // namespace util +// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. +// Automatically handles vectorized loads/stores in the MMA operation. +// Loads 8x8 matrix into a warp. Thread 0-7 provide the ptr that is the start +// of each row. All other threads can simply point to something valid +// (including 0). +// The x2 modifier on the instruction will actually load 2x8 rows to make a +// 16x8, +// then thread 0-15 will specify the start of each row. +// Finally is an x4 modifier producing a 32x8 using addrs from 0-31 in each +// warp. DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); @@ -28,6 +41,9 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { : "r"(addr)); } +// Same as previous, 8x8 matrix is vectorized loaded, then scattered (to perform +// transpose) so threads will hold 2 values down a column (instead of the +// previous instruction that's across a row). DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); @@ -67,17 +83,22 @@ namespace util { // Special utility for cp_async DEVICE_INLINE unsigned toSmem(void* ptr) { - unsigned smem_ptr; + unsigned smem_ptr_uint; + // Declare 64 bit register smem_ptr + // Convert the input to a shared memory pointer + // Convert to unsigned 32 bit pointer asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" - : "=r"(smem_ptr) + : "=r"(smem_ptr_uint) : "l"(ptr)); - return smem_ptr; + return smem_ptr_uint; } } // namespace util +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. 
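// Completion can also be scoped to groups of copies instead of waiting on
// everything; a sketch of such wrappers for reference (not part of this patch,
// names are made up here, and keep_stages must be a compile-time constant):
DEVICE_INLINE void cpAsyncCommitGroup() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int keep_stages>
DEVICE_INLINE void cpAsyncWaitGroup() {
  // Blocks until at most keep_stages committed groups are still in flight.
  asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}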
template DEVICE_INLINE void cpAsync( Array* smem_ptr, From a2c28c947c7c201cb15e817437adeb70b6a258c0 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 25 Apr 2022 15:26:20 -0700 Subject: [PATCH 25/57] leave cp async in a separate PR --- torch/csrc/jit/codegen/cuda/codegen.cpp | 14 -- .../jit/codegen/cuda/lower_double_buffer.cpp | 14 +- .../jit/codegen/cuda/lower_double_buffer.h | 22 --- .../jit/codegen/cuda/lower_insert_syncs.cpp | 36 +---- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 71 --------- .../cuda/lower_predicate_elimination.cpp | 13 -- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 16 -- .../jit/codegen/cuda/lower_validation.cpp | 6 +- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 52 ------- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 139 ------------------ torch/csrc/jit/codegen/cuda/type.cpp | 2 - torch/csrc/jit/codegen/cuda/type.h | 5 - 12 files changed, 3 insertions(+), 387 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 60b5cfb34b13..7d4d11a1ade8 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -487,14 +487,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } - void genCpAsync(const UnaryOp* uop, int vec_size) { - auto dtype = uop->in()->getDataType().value(); - - indent() << "Ampere::cpAsync(" - << genVectorPointer(uop->out(), dtype, vec_size) << "," - << genVectorPointer(uop->in(), dtype, vec_size) << ");\n"; - } - void genLdMatrix(const UnaryOp* uop, int vector_word_size) { auto dtype = uop->in()->getDataType().value(); indent() << "Turing::ldMatrix"; @@ -554,7 +546,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (vectorize_op) { TORCH_INTERNAL_ASSERT( uop->getUnaryOpType() == UnaryOpType::Set || - uop->getUnaryOpType() == UnaryOpType::CP_ASYNC || uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT, "Cannot vectorize operations that are not sets. ", @@ -575,11 +566,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (is_vector_op) { // Note: Non-vectorized cp async isn't yet supported. // will support in a follow up. - if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - genCpAsync(uop, vector_word_size); - return; - } - // TODO: do we want to define a unary op called memory/copy/ldst? if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index bc9f52c9fb26..f02c35964fb3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -69,7 +69,6 @@ void validateDoubleBufferedTensor(const TensorView* tv) { TORCH_INTERNAL_ASSERT( def->isA() && def->as()->getUnaryOpType() == UnaryOpType::Set || - def->as()->getUnaryOpType() == UnaryOpType::CP_ASYNC || def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIX || def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIXT, "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set is supported: ", @@ -128,15 +127,6 @@ class DoubleBufferFusionInspector : private IterVisitor { TORCH_INTERNAL_ASSERT( tv->definition(), "Fusion input shouldn't be double buffered.", tv); - // Record that there is a double buffered tensor that is a consumer - // of async gmem load. The double buffer sync will have to sync - // with the async load as well. 
- if (auto uop = dynamic_cast(tv->definition())) { - if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - db_info_.setAsyncGmemDoubleBuffer(); - } - } - validateDoubleBufferedTensor(tv); auto db_axis = getDoubleBufferAxis(tv); @@ -424,9 +414,7 @@ class DoubleBufferInserter : private kir::ExprMutator { if (write_to_smem) { // Here the initial sync before entering double buffer loop is // inserted. It will also need to sync with async gmem loads. - auto sync = IrBuilder::create( - false, - GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer()); + auto sync = IrBuilder::create(false); registerInsertBefore(double_buffer_loop, sync); } diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index d94184141e76..96bc247f4ff6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -128,34 +128,12 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { Val* getOriginalAllocSize(const TensorView* tv); - //! Sets the async gmem load flag indicating that - //! a cpasync load has been double buffered. - void setAsyncGmemDoubleBuffer() { - has_async_gmem_double_buffer_ = true; - } - - //! Check if any double buffered tensor is from - //! async gmem load. The current analysis is - //! simplistic to begin with, meaning we always - //! sync with the gmem load if we have any - //! sharedmem double buffered tensor is async - //! loaded. - //! TODO: will need to extend to support more complex - //! pipeline patterns in follow ups. - bool hasAsyncGmemDoubleBuffer() { - return has_async_gmem_double_buffer_; - } - private: TvInfo& getTvInfo(const TensorView* tv); private: //! Keeps track of information for lowering double buffered tensors std::unordered_map map_; - - //! Tracks if any double buffered tensor is consumer of async - //! gmem load. - bool has_async_gmem_double_buffer_ = false; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index c756a0f51afc..4a4d1e702dda 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -134,19 +134,6 @@ class WarSyncInserter : private kir::ExprMutator { for (const auto& entry : lower_alloc_info_map) { alloc_map_.insert(entry.first); } - - // Currently the RAW sync is removed in the double buffer loop if - // a shared mem tensor is double buffered. In this case the sync - // inserted by war pass is actually both war and raw together. - // This part of the logic extension is for the case when a sharedmem - // tensor is double buffered and async loaded, the double buffer sync - // in this case need to also sync with the async loads. - // Note & TODO: - // might at some point want to separate out as another category like - // "double buffer sync" if we need to keep extending the - // war sync path but actually to prevent raw hazard. 
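// For reference, a minimal standalone sketch of why a single sync in a double
// buffered loop covers both hazards at once. This is not nvfuser-generated
// code; it assumes one block of 128 threads and num_tiles tiles of 128 floats.
__global__ void doubleBufferedAddSketch(
    const float* in,
    float* out,
    int num_tiles) {
  __shared__ float stage[2][128];
  const int tid = threadIdx.x;
  // Prologue: fill stage 0 before entering the main loop.
  stage[0][tid] = in[tid];
  __syncthreads();
  for (int i = 0; i < num_tiles; ++i) {
    const int read_slot = i % 2;
    const int load_slot = (i + 1) % 2;
    if (i + 1 < num_tiles) {
      // Prefetch the next tile into the slot not read this iteration.
      stage[load_slot][tid] = in[(i + 1) * 128 + tid];
    }
    // Consume the current tile, reading a neighboring thread's element.
    out[i * 128 + tid] =
        stage[read_slot][tid] + stage[read_slot][(tid + 1) % 128];
    // One sync serves both hazards: WAR, so the next prefetch does not
    // overwrite a slot still being read, and RAW, so the tile just prefetched
    // is visible before the next iteration consumes it.
    __syncthreads();
  }
}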
- need_to_sync_gmem_ = - GpuLower::current()->doubleBufferInfo().hasAsyncGmemDoubleBuffer(); kir::ExprMutator::traverseAndInsert(exprs); } @@ -533,10 +520,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { sync_expr = IrBuilder::create( sync_bitmap, maybe_alloc->buffer()); } else { - bool need_gmem_sync = gmem_sync_before_.count(expr); - sync_expr = IrBuilder::create( - false, // is not war sync - need_gmem_sync); + sync_expr = IrBuilder::create(false); // is not war sync } // The expressions in last_writes are those we're protecting the read @@ -705,18 +689,6 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { auto last_smem_writes = isModifiedSharedMemory(smem, expr->inputs()); - // RAW sync insertion in the case where the load isn't double buffered. - // This sync is needed when any of the producer writes is async. - bool need_gmem_sync = std::any_of( - last_smem_writes.begin(), last_smem_writes.end(), [](Expr* expr) { - if (auto uop = dynamic_cast(expr)) { - if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - return true; - } - } - return false; - }); - if (!last_smem_writes.empty()) { TORCH_INTERNAL_ASSERT( prev_tv_expr != nullptr, @@ -727,9 +699,6 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { bitmap.set(ParallelType::TIDz); sync_before_.emplace_back(std::make_pair(expr, bitmap)); last_writes_.push_back(last_smem_writes); - if (need_gmem_sync) { - gmem_sync_before_.insert(expr); - } smem.clear(); } @@ -769,9 +738,6 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { //! it is not placed before those write expressions. std::deque> last_writes_; - //! Keep track of expressions that must be followed by a gmem sync - std::unordered_set gmem_sync_before_; - public: static std::vector insert(const std::vector& loop_nests) { ReadAfterWriteSyncs inserter(loop_nests); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 8daea90b4683..6089a5604d92 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -48,80 +48,9 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { setWritePredicate(expr, conditional); } - // Today for vectorized support the pattern is: - // Initialize buffer -> predicated load - // For memcpy async: - // If we initialized and then loaded (without sync) it would be undefined - // behavior. - // Initialize only the "virtual out of boundary" accesses. - // Memory allocated, but outside the virtual tensor space. - // Virtual tensor space today is effectively what would be allocated in - // global memory. Then only copy the "within bound" accesses. - // This is a WAR today based on how our system is set up. - // We would want to have a separate concept of SMEM space from Virtual or - // GMEM space, so that we know we're only working with the allocated - // SMEM. - // If we hit outside the allocated SMEM bad things happen. - // Today asserting in predicate removal making sure that the virtual and - // SMEM boundaries line up based on the IterDomains. - // - // TODO: in a follow up we need to extend the predicate - // infrastructure to generate predicate for both gmem - // and smem, and the predicate removal will need to - // be extended as well for the perf critical regions. 
- if (isPredicatedInitForCpAsync(expr)) { - invertPredicateForGmemToSharedMemInitialize(expr); - } - kir::IrVisitor::handle(expr); } - void invertPredicateForGmemToSharedMemInitialize(Expr* expr) { - auto pred = expr->predicate()->value(); - auto invert = SimplifyingIrBuilder::notExpr(pred); - expr->predicate()->setValue(invert->as()); - } - - // Detect if this expr is an initialization for vectorized - // cp asyc with predicates. - bool isPredicatedInitForCpAsync(Expr* expr) { - if (expr->predicate() == nullptr) { - return false; - } - if (auto ite = dynamic_cast(expr)) { - if (!ite->elseBody().empty()) { - return false; - } - if (ite->predicate()->predicate_type() != PredicateType::Vectorize) { - return false; - } - for (auto inner_expr : ite->thenBody().exprs()) { - // Checking if this expr writes scalar to shared mem tensor, - // while fuser expr read from gmem to shared mem. - if (auto kir_uop = dynamic_cast(inner_expr)) { - // Initialization ops assumed to have scalar input. - if (!kir_uop->in()->isScalar()) { - continue; - } - if (auto out_tv = dynamic_cast(kir_uop->out())) { - auto fuser_expr = out_tv->view()->definition(); - if (fuser_expr == nullptr) { - // This shouldn't hit but added this skip here in case - // any future kir ops have this behavior. - continue; - } - if (auto fuser_uop = dynamic_cast(fuser_expr)) { - if (fuser_uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - return true; - } - } - } - } - } - } - return false; - } - void setWritePredicate(Expr* expr, Bool* read_cond) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 8bd2a49fd202..65eb97487b6d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -104,19 +104,6 @@ class PredicateAnalyzer : public OptOutDispatch { } } - // Restricted usage for cp async. - // TODO: - // cp async initialization and data load path are async - // so would need support for multiple predicate to handle - // generic support. Currently only support the use cases where - // the smem address access will never be out of bound. 
- auto consumer_def = dynamic_cast(consumer->definition()); - if (consumer_def && - consumer_def->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - TORCH_INTERNAL_ASSERT( - !needs_sharedmem_addr_pred, "unsupported usage for cp async"); - } - if (needs_sharedmem_addr_pred) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index e59274f72f1a..6da186a6ad66 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -121,16 +121,6 @@ void UnrollPass::handle(Expr* expr) { return false; }; - // Check if the given op is an async copy - auto is_cpasync_op = [](Expr* expr) { - if (auto uop = dynamic_cast(expr)) { - if (uop->getUnaryOpType() == UnaryOpType::CP_ASYNC) { - return true; - } - } - return false; - }; - bool is_ld_matrix = false; bool is_ld_matrix_producer = false; if (is_ld_matrix_op(expr)) { @@ -160,12 +150,6 @@ void UnrollPass::handle(Expr* expr) { val->as()->getMemoryType() == MemoryType::Global; })) { for (auto out_tv : ir_utils::filterByType(expr->outputs())) { - if (is_cpasync_op(out_tv->definition())) { - // cp async op has initialization outside of the predicate protected - // boundary so no need to ensure predicate removal. - break; - } - for (auto use : expr->fusion()->unordered_uses(out_tv)) { if (is_ld_matrix_op(use)) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 27d5d4d1f4c4..d6c3ce9342a7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -565,10 +565,9 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { tv->definition() == nullptr || (tv->definition()->isA() && (uop_type == UnaryOpType::Set || - uop_type == UnaryOpType::CP_ASYNC || uop_type == UnaryOpType::LD_MATRIX || uop_type == UnaryOpType::LD_MATRIXT)), - "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.", + "Vectorized accesses cannot be inline with computation, they are only supported with a Set like operation.", "TensorView: ", tv); } @@ -1011,9 +1010,6 @@ void validateArchMemoryOp(UnaryOp* uop) { validateMinimumArch(7, 5); validateLdMatrixOutput(uop->out()->as()); return; - case UnaryOpType::CP_ASYNC: - validateMinimumArch(8, 0); - return; default: return; } diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 8a3e93e54196..060f2920b0e3 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -73,56 +73,4 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { #endif // Arch 75 -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - -namespace Ampere { - -// MMA instruction wrappers (sm_80+): - -namespace util { - -// Special utility for cp_async -DEVICE_INLINE unsigned toSmem(void* ptr) { - unsigned smem_ptr_uint; - - // Declare 64 bit register smem_ptr - // Convert the input to a shared memory pointer - // Convert to unsigned 32 bit pointer - asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" - : "=r"(smem_ptr_uint) - : "l"(ptr)); - - return smem_ptr_uint; -} - -} // namespace util - -// Global to SMEM load that is asynchronous, -// not guaranteed to be completed until cpAsyncBarrier() is called. 
-template -DEVICE_INLINE void cpAsync( - Array* smem_ptr, - void const* gmem_ptr) { - unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); - constexpr int byte_size = sizeof(dtype) * len; - - static_assert( - byte_size == 4 || byte_size == 8 || byte_size == 16, - "cp_async : unsupported byte size"); - - asm volatile( - "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr), - "l"(gmem_ptr), - "n"(byte_size)); -} - -// TODO: Might have a different category of sync if we want to build out this: -DEVICE_INLINE void cpAsyncBarrier() { - asm volatile("cp.async.wait_all;"); -} - -} // namespace Ampere - -#endif // Arch 80 - #undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 0f6fc63fa4f6..cce768e5e6e8 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2209,145 +2209,6 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); } -// Matmul test on ampere, using ampere memory ops -TEST_F(NVFuserTest, FusionAmpereGemmTNcpAsync_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 255, N = 511, K = 88; - - // [M,K] - auto tv0 = makeContigTensor(2, DataType::Half); - // [N,K] - auto tv1 = makeContigTensor(2, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [M,N,K] - auto tv0b = broadcast(tv0, {false, true, false}); - auto tv1b = broadcast(tv1, {true, false, false}); - - // Leaving both sets of mma inputs for volta outside - // currently since they need to be swizzled. - auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) - .layout(MmaOptions::MmaInputLayout::TN); - - tv2->configureMma(mma_builder.build()); - - auto tv0cw = tv0->cacheAfter(UnaryOpType::CP_ASYNC); - auto tv0cr = tv0cw->cacheAfter(UnaryOpType::LD_MATRIX); - auto tv1cw = tv1->cacheAfter(UnaryOpType::CP_ASYNC); - auto tv1cr = tv1cw->cacheAfter(UnaryOpType::LD_MATRIX); - auto tv2c = tv2->cacheBefore(); - - // Make a CTA tile - // ------------------------------------------------------------------ - // [M,N] - tv2->split(-2, gemm_tile.cta_tile.m); - tv2->split(-1, gemm_tile.cta_tile.n); - - // 0 1 2 3 - // [Mo,M128, No, N128] - tv2->reorder({{1, 2}, {2, 1}}); - - // 0 1 2 3 - // [Mo,No, M128, N128] - tv0->computeAt(tv2, 2); - tv1->computeAt(tv2, 2); - - // Order K - // 0 1 2 3 4 5 - // [Mo,No, M128, N128, Ko, K32] - tv2c->split(-1, gemm_tile.cta_tile.k); - tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); - - // 0 1 2 3 4 5 - // [Mo,No, Ko M128, N128, K32] - tv0cw->computeAt(tv2c, 3); - tv1cw->computeAt(tv2c, 3); - - // Make warp tile: - // ------------------------------------------------------------------------- - scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); - scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( - tv2, gemm_tile); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] - tv0cr->computeAt(tv2c, -4); - tv1cr->computeAt(tv2c, -4); - - // Schedule gmem read and smem write: - // --------------------------------------------------------------------------- - // 
[Mo,Ko,M,K] - tv0cw->merge(-2); - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv0cw, gemm_tile, 8); - tv0cw->setMemoryType(MemoryType::Shared); - // [Mo,Ko,i,wy,wx,v] - - // [No,Ko,N,K] - tv1cw->merge(-2); - // [No,Ko,i,wy,wx,v] - scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( - tv1cw, gemm_tile, 8); - tv1cw->setMemoryType(MemoryType::Shared); - // Schedule mma input - // --------------------------------------------------------------------------- - tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - // [... Mi, Ni, Ki] - tv0b->reorder({{-2, -3}, {-3, -2}}); - tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); - - tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); - - // Schedule mma output - // --------------------------------------------------------------------------- - tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); - - // Parallelize - // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] - tv2c->axis(3)->parallelize(ParallelType::TIDz); - tv2c->axis(4)->parallelize(ParallelType::TIDy); - - // Parallelize - // 0 1 2 3 4 5 6 7 - // [Mo No Mwo Nwo Mw Nw (Mi Ni)] - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(2)->parallelize(ParallelType::TIDz); - tv2->axis(3)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto t0 = at::randn({M, K}, options); - auto t1 = at::randn({N, K}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - auto cg_outputs = fe.runFusion({t0, t1}); - - auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); - - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); -} - #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 67702380a9ca..b5214151f219 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -344,8 +344,6 @@ static const char* unary_op_type2string(UnaryOpType t) { return "tanh"; case UnaryOpType::Trunc: return "trunc"; - case UnaryOpType::CP_ASYNC: - return "cpasync"; case UnaryOpType::LD_MATRIX: return "ldmatrix"; case UnaryOpType::LD_MATRIXT: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 894d68b1bf86..d6398a225e2a 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -155,11 +155,6 @@ enum class UnaryOpType { Not, // Memory ops, - // TODO & Note: Should consider moving these into - // a separate/new IR node, maybe LDST? Also should be moved - // are most of the usage cases of UnaryOp::Set, - // which are indeed representing memory read writes. 
- CP_ASYNC, LD_MATRIX, LD_MATRIXT }; From d28574c2c8f76e5adcdb77cc5ad4d43436a0fa98 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 25 Apr 2022 15:43:24 -0700 Subject: [PATCH 26/57] cleanup --- torch/csrc/jit/codegen/cuda/codegen.cpp | 5 ---- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 6 ++--- torch/csrc/jit/codegen/cuda/kernel_ir.h | 23 +------------------ .../jit/codegen/cuda/lower_insert_syncs.cpp | 3 +-- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 7d4d11a1ade8..f477aed0b890 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1854,11 +1854,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::BlockSync* sync) final { - // Issue a barrier to sync with async loads, - // see [Note: dma sync with gmem loads] - if (sync->syncGmem()) { - indent() << "Ampere::cpAsyncBarrier();\n"; - } // Use a custom synchronization method if enabled if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { indent() << "block_sync::sync();\n"; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 43af0282aeda..86291be901b5 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -78,10 +78,8 @@ TensorIndex::TensorIndex( } } -BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync, bool gmem_sync) - : Expr(passkey, ExprType::BlockSync), - war_sync_(war_sync), - gmem_sync_(gmem_sync) { +BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) + : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 288122b9a4e5..636d5a18771d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -244,36 +244,15 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API BlockSync final : public Expr { public: - explicit BlockSync( - IrBuilderPasskey passkey, - bool war_sync = false, - bool gmem_sync = false); + explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); bool isWarHazardSync() const { return war_sync_; } - bool syncGmem() const { - return gmem_sync_; - } - private: // TODO: war_sync_ is only used for testing/validation purposes. bool war_sync_ = false; - - //! [Note: Dma sync with gmem loads] - //! Indicates if this block sync also synchronizes with asynchronous - //! copy ops from global mem to shared mem. - //! Currently making this a parameter in BlockSync so that we always - //! sync the blocks after synchronizing with the async gmem loads to - //! avoid having to model more complex sync patterns. - //! For most cases it could be assumed that output of async gmem op - //! in the shared memory will be accessed by other threads as well - //! since otherwise wouldn't need to load to shared mem. - //! An exception might be if we use shared mem as extension of - //! register space and we could model this sync as a separate op - //! if we see a use case. 
- bool gmem_sync_ = false; }; // Synchronize all blocks in device, implies cooperative group launch is diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 4a4d1e702dda..a146432d46e4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -275,8 +275,7 @@ class WarSyncInserter : private kir::ExprMutator { // WAR Sync is necessary in this loop, register its insertion. if (insert_sync) { - auto sync_expr = - IrBuilder::create(true, need_to_sync_gmem_); + auto sync_expr = IrBuilder::create(true); kir::ExprMutator::registerInsertAfter( for_loop->body().exprs().back(), sync_expr, &for_loop->body()); handle(sync_expr); From 2fa3a92f69d73bb3128ffbd666fb79cb034df4a6 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 26 Apr 2022 11:27:59 -0700 Subject: [PATCH 27/57] move memory op to separate IR node --- torch/csrc/jit/codegen/cuda/codegen.cpp | 67 ++++++++++++++----- torch/csrc/jit/codegen/cuda/dispatch.cpp | 15 +++++ torch/csrc/jit/codegen/cuda/dispatch.h | 4 ++ torch/csrc/jit/codegen/cuda/ir_builder.cpp | 1 + torch/csrc/jit/codegen/cuda/ir_cloner.cpp | 4 ++ torch/csrc/jit/codegen/cuda/ir_cloner.h | 1 + .../jit/codegen/cuda/ir_interface_nodes.h | 6 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 30 +++++++++ torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 5 ++ torch/csrc/jit/codegen/cuda/ir_iostream.h | 1 + torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 19 ++++++ torch/csrc/jit/codegen/cuda/ir_utils.cpp | 15 +++++ .../jit/codegen/cuda/lower_double_buffer.cpp | 4 +- torch/csrc/jit/codegen/cuda/lower_index.cpp | 6 ++ torch/csrc/jit/codegen/cuda/lower_index.h | 1 + .../jit/codegen/cuda/lower_insert_syncs.cpp | 3 - .../cuda/lower_predicate_elimination.cpp | 3 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 39 +++-------- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 18 +++++ torch/csrc/jit/codegen/cuda/lower_utils.h | 4 ++ .../jit/codegen/cuda/lower_validation.cpp | 33 ++++----- torch/csrc/jit/codegen/cuda/mma_type.cpp | 8 +-- torch/csrc/jit/codegen/cuda/mma_type.h | 2 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 14 ++++ torch/csrc/jit/codegen/cuda/root_domain_map.h | 4 ++ .../jit/codegen/cuda/scheduler/mma_utils.cpp | 33 ++++----- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 20 ++++-- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 16 ++--- torch/csrc/jit/codegen/cuda/type.cpp | 23 +++++-- torch/csrc/jit/codegen/cuda/type.h | 12 ++-- 30 files changed, 294 insertions(+), 117 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index f477aed0b890..8be6c8f09d08 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -487,15 +487,16 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } - void genLdMatrix(const UnaryOp* uop, int vector_word_size) { - auto dtype = uop->in()->getDataType().value(); + void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) { + auto dtype = ldst->in()->getDataType().value(); indent() << "Turing::ldMatrix"; - if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { + if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) { code_ << "T"; } code_ << " ("; - code_ << "*" << genVectorPointer(uop->out(), dtype, vector_word_size) << "," - << "&" << gen(uop->in()) << ");\n"; + code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size) + << "," + << "&" << gen(ldst->in()) << ");\n"; } void handle(const 
UnaryOp* uop) final { @@ -545,9 +546,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (vectorize_op) { TORCH_INTERNAL_ASSERT( - uop->getUnaryOpType() == UnaryOpType::Set || - uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || - uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT, + uop->getUnaryOpType() == UnaryOpType::Set, "Cannot vectorize operations that are not sets. ", "Use cacheBefore and cacheAfter to store/load with vectorized reads into buffers."); is_vector_op = true; @@ -564,15 +563,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { } if (is_vector_op) { - // Note: Non-vectorized cp async isn't yet supported. - // will support in a follow up. - // TODO: do we want to define a unary op called memory/copy/ldst? - if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || - uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { - genLdMatrix(uop, vector_word_size); - return; - } - auto out_tv = uop->out()->as()->view(); if (uop->in()->isScalar()) { // Note: @@ -1138,6 +1128,49 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + void handle(const LoadStoreOp* ldst) { + // TODO: + // Need to gradually merge the code path of this + // with UnaryOp::Set for vectorization. + // There is quite a bit of possible clean up. + bool vectorize_op = false; + size_t vector_word_size = 1; + auto ti = ldst->out()->as(); + + // Check vectorization and set vector word size + for (auto id : ti->view()->domain()->domain()) { + if (!isParallelTypeVectorize(id->getParallelType())) { + continue; + } + + ExpressionEvaluator expr_eval(id->fusion()); + auto vector_size_optional = expr_eval.evaluate(id->extent()); + + TORCH_INTERNAL_ASSERT( + vector_size_optional.has_value(), + "Could not evaluate constant value bound to vectorized dim."); + + TORCH_INTERNAL_ASSERT( + id->getParallelType() != ParallelType::MisalignedVectorize, + "LoadStoreOp: no support yet for mis-aligned vectorization"); + vector_word_size = vector_size_optional.value(); + vectorize_op = true; + break; + } + + // Dispatch instruction generation: + switch (ldst->opType()) { + case LoadStoreOpType::LdMatrix: + case LoadStoreOpType::LdMatrixTranspose: + TORCH_INTERNAL_ASSERT( + vectorize_op, "LdMatrix: Vectorization required: ", ldst); + genLdMatrix(ldst, vector_word_size); + break; + default: + TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type"); + } + } + void handle(const WelfordOp* wop) final { TORCH_INTERNAL_ASSERT(wop->out()->isA()); diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index dc7ac6403d65..d0d73277bf63 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -104,6 +104,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::LoadStoreOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::MmaOp: ptr(handler)->handle(expr->as()); return; @@ -245,6 +248,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::LoadStoreOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::MmaOp: ptr(handler)->handle(expr->as()); return; @@ -397,6 +403,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::WelfordOp: ptr(mutator)->mutate(expr->as()); return; + case ExprType::LoadStoreOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::MmaOp: ptr(mutator)->mutate(expr->as()); return; @@ -614,6 
+623,9 @@ void OptOutConstDispatch::handle(const ReductionOp* stmt) { void OptOutConstDispatch::handle(const WelfordOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const LoadStoreOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const MmaOp* stmt) { unhandled(stmt); } @@ -728,6 +740,9 @@ void OptOutDispatch::handle(ReductionOp* stmt) { void OptOutDispatch::handle(WelfordOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(LoadStoreOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(MmaOp* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index c38641cee580..420b490a457c 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -73,6 +73,7 @@ class BinaryOp; class TernaryOp; class ReductionOp; class WelfordOp; +class LoadStoreOp; class MmaOp; class BroadcastOp; class TransposeOp; @@ -133,6 +134,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const TernaryOp* stmt); virtual void handle(const ReductionOp* stmt); virtual void handle(const WelfordOp* stmt); + virtual void handle(const LoadStoreOp* stmt); virtual void handle(const MmaOp* stmt); virtual void handle(const BroadcastOp* stmt); @@ -186,6 +188,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(TernaryOp* stmt); virtual void handle(ReductionOp* stmt); virtual void handle(WelfordOp* stmt); + virtual void handle(LoadStoreOp* stmt); virtual void handle(MmaOp* stmt); virtual void handle(BroadcastOp* stmt); @@ -280,6 +283,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(TernaryOp*); virtual void mutate(ReductionOp*); virtual void mutate(WelfordOp*); + virtual void mutate(LoadStoreOp*); virtual void mutate(MmaOp*); virtual void mutate(BroadcastOp*); diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index c17ff0de44a4..64a656496567 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -63,6 +63,7 @@ IR_BUILDER_INSTANTIATE(BinaryOp) IR_BUILDER_INSTANTIATE(TernaryOp) IR_BUILDER_INSTANTIATE(ReductionOp) IR_BUILDER_INSTANTIATE(WelfordOp) +IR_BUILDER_INSTANTIATE(LoadStoreOp) IR_BUILDER_INSTANTIATE(MmaOp) IR_BUILDER_INSTANTIATE(BroadcastOp) diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 1ddc4feb90da..622908ab7889 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -112,6 +112,10 @@ void IrCloner::handle(const WelfordOp* op) { clone_ = IrBuilder::clone(op, this); } +void IrCloner::handle(const LoadStoreOp* op) { + clone_ = IrBuilder::clone(op, this); +} + void IrCloner::handle(const MmaOp* op) { clone_ = IrBuilder::clone(op, this); } diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 3f50cd48e93b..153f93f863cd 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -74,6 +74,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const BroadcastOp*) override; void handle(const ReductionOp*) override; void handle(const WelfordOp*) override; + void handle(const LoadStoreOp*) override; void handle(const MmaOp*) override; void handle(const TransposeOp*) override; void handle(const ShiftOp*) override; diff --git 
a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index d5dc982a6a2a..45873a9fdac0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -410,14 +410,16 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! //! @param cache_op: memory operator to use for the inserted op between //! the the data tensor and the cache tensor - TensorView* cacheBefore(UnaryOpType cache_op = UnaryOpType::Set); + TensorView* cacheBefore( + c10::optional cache_op = c10::nullopt); //! Create a TensorView after the original tensor. A common use case is to //! read tensor into shared memory or registers. Analogous to TVM Cache_Read //! //! @param cache_op: memory operator to use for the inserted op between //! the the data tensor and the cache tensor - TensorView* cacheAfter(UnaryOpType cache_op = UnaryOpType::Set); + TensorView* cacheAfter( + c10::optional cache_op = c10::nullopt); // For a fusion output with other uses, we want to avoid writing to global // memory and then reading the output again. We write to global memory diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index f58dcb7129d7..f43682358584 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -549,6 +549,36 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { TensorView* const in_ = nullptr; }; +//! This operator explicitly models data movement between +//! state spaces on GPU. Currently the modeled state spaces include +//! global memory, shared memory and register. +//! +//! The main usage of this op is to facilitate generation of hardware +//! accelerated memory ops, i.e. ldmatrix, cp.async and more to come. 
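+//!
+//! A minimal usage sketch, assuming a shared memory tensor tv0cw that an
+//! mma operand is read from (mirrors the unit tests in this patch series):
+//!
+//!   // Inserts a LoadStoreOp of type LdMatrix between tv0cw and the
+//!   // newly created cache tensor tv0cr.
+//!   auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);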
+class TORCH_CUDA_CU_API LoadStoreOp : public Expr { + public: + LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in); + + LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); + + Val* out() const { + return out_; + } + + Val* in() const { + return in_; + } + + LoadStoreOpType opType() const { + return load_store_type_; + } + + private: + LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix; + Val* const out_ = nullptr; + Val* const in_ = nullptr; +}; + // Friends for direct access to split class TensorDomain; class ReplayTransformations; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 4cacb8de402d..9fc1afb81b46 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -415,6 +415,11 @@ void IrPrinter::handle(const WelfordOp* wop) { os_ << " )\n"; } +void IrPrinter::handle(const LoadStoreOp* ldst) { + indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in() + << " )\n"; +} + void IrPrinter::handle(const BroadcastOp* bop) { indent() << bop->out() << " = broadcast( " << bop->in() << " )\n"; } diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index e25e6ef0f865..2c84652cce1f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -87,6 +87,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const TernaryOp*) final; void handle(const ReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const LoadStoreOp*) final; void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; void handle(const TransposeOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 3b5c99b4c285..c75e32a358ad 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -832,6 +832,25 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +LoadStoreOp::LoadStoreOp( + IrBuilderPasskey passkey, + LoadStoreOpType op_type, + Val* out, + Val* in) + : Expr(passkey, ExprType::LoadStoreOp), + load_store_type_(op_type), + out_(out), + in_(in) { + addOutput(out); + addInput(in); +} + +LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + load_store_type_(src->load_store_type_), + out_(ir_cloner->clone(src->out_)), + in_(ir_cloner->clone(src->in_)) {} + IterDomain::IterDomain( IrBuilderPasskey passkey, Val* start, diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 60dea3fa1950..401e8093a1ed 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -328,6 +328,21 @@ struct SubstituteInExpr : public OptInDispatch { welford_expr->isFused()); } + void handle(LoadStoreOp* ldst_expr) final { + TORCH_INTERNAL_ASSERT( + substitute_->isA(), + "All args to view must be TensorView, but received a non-TensorView for replacement: ", + substitute_); + auto in = reference_->sameAs(ldst_expr->in()) + ? substitute_->as() + : ldst_expr->in(); + auto out = reference_->sameAs(ldst_expr->out()) + ? 
substitute_->as() + : ldst_expr->out(); + expr_ = IrBuilder::create( + ldst_expr->container(), ldst_expr->opType(), out, in); + } + void handle(MmaOp* mma_expr) final { TORCH_INTERNAL_ASSERT( substitute_->isA(), diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index f02c35964fb3..152180218c92 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -69,8 +69,8 @@ void validateDoubleBufferedTensor(const TensorView* tv) { TORCH_INTERNAL_ASSERT( def->isA() && def->as()->getUnaryOpType() == UnaryOpType::Set || - def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIX || - def->as()->getUnaryOpType() == UnaryOpType::LD_MATRIXT, + // Load store op should generally support double buffering. + def->isA(), "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set is supported: ", def->toString()); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 9e429fb43904..3d302d28e824 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -441,6 +441,12 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { } } +void IndexLowering::handle(const LoadStoreOp* ldst) { + const auto in = lowerSrcIndex(ldst->in(), ldst->out()); + const auto out = lowerDstIndex(ldst->out()); + pushBack(IrBuilder::create(ldst->opType(), out, in)); +} + void IndexLowering::handle(const MmaOp* mma) { const auto a = lowerSrcIndex(mma->inA(), mma->out()); const auto b = lowerSrcIndex(mma->inB(), mma->out()); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 94a7961c77ea..6f4bde7ab47a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -38,6 +38,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const TernaryOp*) final; void handle(const ReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const LoadStoreOp*) final; void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index a146432d46e4..776635d26cae 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -332,9 +332,6 @@ class WarSyncInserter : private kir::ExprMutator { // alias tv, each aliased tv in a unique ca_loop has to be tracked separately // for WAR insertion. std::unordered_map> smem_allocations_; - - //! Tracks if the inserted syncs need to also sync with the gmem async loads. 
- bool need_to_sync_gmem_ = false; }; class ExprFlattener : private kir::IrVisitor { diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 65eb97487b6d..257ecc3c4080 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -25,8 +25,7 @@ namespace { void assertOnWarpOps(const Expr* expr) { if (auto uop = dynamic_cast(expr)) { TORCH_INTERNAL_ASSERT( - uop->getUnaryOpType() != UnaryOpType::LD_MATRIX || - uop->getUnaryOpType() != UnaryOpType::LD_MATRIXT, + !ir_utils::isLdMatrixOp(expr), "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims", expr); } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 6da186a6ad66..c4e3aba72a85 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -110,20 +110,9 @@ void UnrollPass::handle(Expr* expr) { // Vectorized expressions should never use inline predicates kir::Predicate* pred = nullptr; - // Check if the given op is an ldmatrix - auto is_ld_matrix_op = [](Expr* expr) { - if (auto uop = dynamic_cast(expr)) { - if (uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || - uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT) { - return true; - } - } - return false; - }; - bool is_ld_matrix = false; bool is_ld_matrix_producer = false; - if (is_ld_matrix_op(expr)) { + if (ir_utils::isLdMatrixOp(expr)) { // Note & TODO: currently cannot support predicated ldmatrix since it'd // need to make sure all threads in a warp evaluate the predicate // to the same value. @@ -144,22 +133,16 @@ void UnrollPass::handle(Expr* expr) { // data outside of the producer's valid region. // Except for vectorized global loads which should be properly // initialized outside of the predicated region. 
- if (!std::any_of( - expr->inputs().begin(), expr->inputs().end(), [](Val* val) { - return val->isA() && - val->as()->getMemoryType() == MemoryType::Global; - })) { - for (auto out_tv : ir_utils::filterByType(expr->outputs())) { - for (auto use : expr->fusion()->unordered_uses(out_tv)) { - if (is_ld_matrix_op(use)) { - TORCH_INTERNAL_ASSERT( - GpuLower::current()->predicateElimination().canOmitPredicate( - expr), - "Vectorize: producers of ldmatrix cannot be predicated except for global loads:", - expr); - is_ld_matrix_producer = true; - break; - } + for (auto out_tv : ir_utils::filterByType(expr->outputs())) { + for (auto use : expr->fusion()->unordered_uses(out_tv)) { + if (ir_utils::isLdMatrixOp(use)) { + TORCH_INTERNAL_ASSERT( + GpuLower::current()->predicateElimination().canOmitPredicate( + expr), + "Vectorize: producers of ldmatrix cannot be predicated except for global loads:", + expr); + is_ld_matrix_producer = true; + break; } } } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 4763dd41ad82..30ac12e43db5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -93,6 +93,7 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::TernaryOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::WelfordOp || + expr->getExprType().value() == ExprType::LoadStoreOp || expr->getExprType().value() == ExprType::MmaOp || expr->getExprType().value() == ExprType::BroadcastOp || expr->getExprType().value() == ExprType::TransposeOp || @@ -108,6 +109,14 @@ bool isTvOp(const Expr* expr) { return false; } +bool isLdMatrixOp(const Expr* expr) { + if (auto ldst = dynamic_cast(expr)) { + return ldst->opType() == LoadStoreOpType::LdMatrix || + ldst->opType() == LoadStoreOpType::LdMatrixTranspose; + } + return false; +} + TensorView* getTv(Val* val) { if (val->isA()) { return val->as(); @@ -471,6 +480,15 @@ class ReplaceExprInput : private kir::ExprMutator { } } + void handle(LoadStoreOp* node) final { + auto replaced_inputs = getMaybeInputReplacementMap(node); + if (replaced_inputs.has_value()) { + auto replacement = IrBuilder::create( + node->opType(), node->out(), node->in()); + registerReplaceWithPredicate(node, replacement); + } + } + private: const std::unordered_map& replacement_map_; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 200c729b191c..8a05ca7f22a9 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -95,6 +95,10 @@ bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis); std::unordered_map getParallelDomains( Val* val); +//! Returns true if the expression will be lowered to +//! a ldmatrix intrinsic. 
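+//! i.e. a LoadStoreOp whose opType() is LoadStoreOpType::LdMatrix or
+//! LoadStoreOpType::LdMatrixTranspose.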
+bool isLdMatrixOp(const Expr* expr); + } // namespace ir_utils namespace loop_utils { diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index d6c3ce9342a7..889597d6ff3d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -557,17 +557,13 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { } } if (has_vectorize_dim) { - auto uop_type = UnaryOpType::Abs; - if (tv->definition() && tv->definition()->isA()) { - uop_type = tv->definition()->as()->getUnaryOpType(); - } TORCH_INTERNAL_ASSERT( tv->definition() == nullptr || (tv->definition()->isA() && - (uop_type == UnaryOpType::Set || - uop_type == UnaryOpType::LD_MATRIX || - uop_type == UnaryOpType::LD_MATRIXT)), - "Vectorized accesses cannot be inline with computation, they are only supported with a Set like operation.", + tv->definition()->as()->getUnaryOpType() == + UnaryOpType::Set) || + tv->definition()->isA(), + "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.", "TensorView: ", tv); } @@ -966,11 +962,7 @@ void validateTuringMmaInput(TensorView* tv) { tv_def = tv_def->input(0)->definition(); } TORCH_INTERNAL_ASSERT(tv_def); - auto tv_uop = dynamic_cast(tv_def); - TORCH_INTERNAL_ASSERT(tv_uop); - TORCH_INTERNAL_ASSERT( - tv_uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || - tv_uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT); + TORCH_INTERNAL_ASSERT(ir_utils::isLdMatrixOp(tv_def)); } // Output of ldmatrix is swizzled with the mma format, so it @@ -1002,13 +994,12 @@ void validateLdMatrixOutput(TensorView* tv) { } // Checks that the memory ops are supported on the targeted GPU -void validateArchMemoryOp(UnaryOp* uop) { - auto uop_type = uop->getUnaryOpType(); - switch (uop_type) { - case UnaryOpType::LD_MATRIX: - case UnaryOpType::LD_MATRIXT: +void validateArchMemoryOp(LoadStoreOp* ldst) { + switch (ldst->opType()) { + case LoadStoreOpType::LdMatrix: + case LoadStoreOpType::LdMatrixTranspose: validateMinimumArch(7, 5); - validateLdMatrixOutput(uop->out()->as()); + validateLdMatrixOutput(ldst->out()->as()); return; default: return; @@ -1043,8 +1034,8 @@ void validateMma(Fusion* fusion) { break; } } - if (auto uop = dynamic_cast(expr)) { - validateArchMemoryOp(uop); + if (auto ldst = dynamic_cast(expr)) { + validateArchMemoryOp(ldst); } } } diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index b0ba8fed846b..35e49b1dc4df 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -43,7 +43,7 @@ MmaOptions MmaBuilder::build() const { namespace { // Utility to get ldmatrix direction a mma layout and operand -UnaryOpType getLdMatrixType(MmaOptions options) { +LoadStoreOpType getLdMatrixType(MmaOptions options) { bool transpose = false; switch (options.macro) { case MmaOptions::MacroType::Ampere_16_8_16: @@ -57,13 +57,13 @@ UnaryOpType getLdMatrixType(MmaOptions options) { TORCH_INTERNAL_ASSERT(false, "unsupported op with ldmatrix"); break; } - - return transpose ? UnaryOpType::LD_MATRIXT : UnaryOpType::LD_MATRIX; + return transpose ? 
LoadStoreOpType::LdMatrix + : LoadStoreOpType::LdMatrixTranspose; } } // namespace -UnaryOpType MmaBuilder::ldMatrix() { +LoadStoreOpType MmaBuilder::ldMatrix() { return getLdMatrixType(option_); } diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index f762d683b3a0..ad2ffa5018e1 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -102,7 +102,7 @@ class TORCH_CUDA_CU_API MmaBuilder { MmaBuilder& layout(MmaOptions::MmaInputLayout layout); MmaBuilder& operand(MmaOptions::Operand a_or_b); MmaOptions build() const; - UnaryOpType ldMatrix(); + LoadStoreOpType ldMatrix(); private: MmaOptions option_; diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 5e397a5bfa11..0f5888a29c77 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -255,6 +255,20 @@ void OptOutMutator::mutate(MmaOp* mma) { IrBuilder::create(container, out, in_a, in_b, init, options); } +void OptOutMutator::mutate(LoadStoreOp* ldst) { + Val* out = maybeMutated(ldst->out()); + Val* in = maybeMutated(ldst->in()); + auto op_type = ldst->opType(); + + if (out->sameAs(ldst->out()) && in->sameAs(ldst->in())) { + return; + } + + auto container = ldst->container(); + container->removeExpr(ldst); + auto new_ldst = IrBuilder::create(container, op_type, out, in); +} + void OptOutMutator::mutate(BroadcastOp* bop) { Val* out = maybeMutated(bop->out()); Val* in = maybeMutated(bop->in()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index c817701e60bb..7b67b06c0b5f 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -403,6 +403,10 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder mapPointwiseOrReductionOp(wop); } + void handle(LoadStoreOp* ldst) override { + mapPointwiseOrReductionOp(ldst); + } + void handle(MmaOp* wop) override { mapPointwiseOrReductionOp(wop); } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index fe50d1e3b41c..0418aa4cf375 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -289,32 +289,33 @@ void validateResultInnerMN(TensorView* tv, int m, int n) { //! 1. direct output of an ldmatrix op or //! 2. direct output of a broadcast op following a ldmatrix op //! Returns true if the tv is an immediate output of ldmatrix op +//! +//! 
TODO: this check is bool checkLdMatrixTv(TensorView* tv) { // First check if tv is an ldmatrix output: auto tv_def = tv->definition(); TORCH_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); - auto tv_def_uop = dynamic_cast(tv_def); bool is_immediate_output = true; - if (tv_def_uop == nullptr) { + if (!ir_utils::isLdMatrixOp(tv_def)) { // Only allow one broadcast in between tv and the ldmatrix op - TORCH_CHECK(tv_def->isA()); + TORCH_CHECK( + tv_def->isA(), + "ldmatrix: only allow serial broadcast between ldmatrix and mma"); tv_def = tv_def->input(0)->definition(); TORCH_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); - tv_def_uop = dynamic_cast(tv_def); - is_immediate_output = false; } - - TORCH_CHECK(tv_def_uop != nullptr, "ldmatrix : invalid op"); + TORCH_CHECK(ir_utils::isLdMatrixOp(tv_def), "ldmatrix : invalid op type"); TORCH_CHECK( - tv_def_uop->getUnaryOpType() == UnaryOpType::LD_MATRIX || - tv_def_uop->getUnaryOpType() == UnaryOpType::LD_MATRIXT, - "ldmatrix : invalid op type"); - TORCH_CHECK(tv->nDims() > 2); - TORCH_CHECK(!tv->axis(-1)->isBroadcast()); - TORCH_CHECK(!tv->axis(-1)->isReduction()); - TORCH_CHECK(!tv->axis(-2)->isBroadcast()); - TORCH_CHECK(!tv->axis(-2)->isReduction()); - + tv->nDims() > 2, + "ldmatrix: scheduled tv needs to be more than 2 dimensional"); + TORCH_CHECK( + !tv->axis(-1)->isBroadcast(), "ldmatrix: unsupported scheduled axes"); + TORCH_CHECK( + !tv->axis(-1)->isReduction(), "ldmatrix: unsupported scheduled axes"); + TORCH_CHECK( + !tv->axis(-2)->isBroadcast(), "ldmatrix: unsupported scheduled axes"); + TORCH_CHECK( + !tv->axis(-2)->isReduction(), "ldmatrix: unsupported scheduled axes"); return is_immediate_output; } diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index d278b1585052..2b54c10367c2 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -775,7 +775,7 @@ WelfordResult TensorView::rFactor( return WelfordResult(producer_avg, producer_var, producer_n); } -TensorView* TensorView::cacheBefore(UnaryOpType cache_op) { +TensorView* TensorView::cacheBefore(c10::optional cache_op) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); @@ -852,7 +852,13 @@ TensorView* TensorView::cacheBefore(UnaryOpType cache_op) { ir_utils::replaceValInExpr(definition(), this, producer); // Expr* producer_uses = - IrBuilder::create(container(), cache_op, consumer, producer); + if (cache_op.has_value()) { + IrBuilder::create( + container(), cache_op.value(), consumer, producer); + } else { + IrBuilder::create( + container(), UnaryOpType::Set, consumer, producer); + } // definition_ is no longer valid // setDefinition(nullptr); @@ -911,7 +917,7 @@ TensorView* TensorView::cacheFork() { return new_output; } -TensorView* TensorView::cacheAfter(UnaryOpType cache_op) { +TensorView* TensorView::cacheAfter(c10::optional cache_op) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); @@ -977,7 +983,13 @@ TensorView* TensorView::cacheAfter(UnaryOpType cache_op) { } // Expr* consumer_definition = - IrBuilder::create(container(), cache_op, consumer, producer); + if (cache_op.has_value()) { + IrBuilder::create( + container(), cache_op.value(), consumer, producer); + } else { + IrBuilder::create( + container(), UnaryOpType::Set, consumer, producer); + } return consumer; } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp 
index cce768e5e6e8..15d3d64be694 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -1629,19 +1629,19 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Gemm 1 main loop read auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); // Gemm 1 accumulator reg auto tv3c = tv3->cacheBefore(); // Gemm 2 main loop read auto tv3cw = tv3h->cacheAfter(); - auto tv3cr = tv3cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2cw = tv2r->cacheAfter(); - auto tv2cr = tv2cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); // Gemm 2 accumulator reg auto tv4c = tv4->cacheBefore(); @@ -1928,9 +1928,9 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Gemm 1 main loop read auto tv0cw = tv0r->cacheAfter(); - auto tv0cr = tv0cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv1cw = tv1r->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); // Gemm 1 accumulator reg auto tv3c = tv3->cacheBefore(); @@ -1942,10 +1942,10 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Gemm 2 main loop read // auto tv3cw = tv3h->cacheAfter(); - auto tv3cr = tv3h->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2cw = tv2r->cacheAfter(); - auto tv2cr = tv2cw->cacheAfter(UnaryOpType::LD_MATRIX); + auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); // Gemm 2 accumulator reg auto tv4c = tv4->cacheBefore(); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index b5214151f219..03bad25de03e 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -208,6 +208,8 @@ static const char* expr_type2string(ExprType t) { return "BroadcastOp"; case ExprType::WelfordOp: return "WelfordOp"; + case ExprType::LoadStoreOp: + return "LoadStoreOp"; case ExprType::MmaOp: return "MmaOp"; case ExprType::TransposeOp: @@ -344,10 +346,6 @@ static const char* unary_op_type2string(UnaryOpType t) { return "tanh"; case UnaryOpType::Trunc: return "trunc"; - case UnaryOpType::LD_MATRIX: - return "ldmatrix"; - case UnaryOpType::LD_MATRIXT: - return "ldmatrixt"; default: TORCH_INTERNAL_ASSERT(false, "No string found for unary op type."); } @@ -632,6 +630,17 @@ static const char* thread_size2string(ParallelType t) { } } +static const char* load_store_type2string(LoadStoreOpType t) { + switch (t) { + case LoadStoreOpType::LdMatrix: + return "LdMatrix"; + case LoadStoreOpType::LdMatrixTranspose: + return "LdMatrixTranspose"; + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type"); + } +} + const unsigned int _WORD_SHIFT = 16; constexpr unsigned int supported_switch_pair(DataType t1, DataType t2) { return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2; @@ -796,6 +805,12 @@ std::ostream& operator<<(std::ostream& out, const IdMappingMode immtype) { return out << id_map_mode_type2string(immtype); } +std::ostream& operator<<( + std::ostream& out, + const LoadStoreOpType load_store_type) { + return out << 
load_store_type2string(load_store_type); +} + TORCH_CUDA_CU_API std::ostream& operator<<( std::ostream& out, const IterType bt) { diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index d6398a225e2a..31f997ab7469 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -96,6 +96,7 @@ enum class ExprType { GatherOp, ViewDtypeOp, ViewOp, + LoadStoreOp, Split, Merge, Allocate, @@ -152,11 +153,7 @@ enum class UnaryOpType { Trunc, // Might be a bitwise operator or boolean operator. - Not, - - // Memory ops, - LD_MATRIX, - LD_MATRIXT + Not }; // Primarily for Not, which could be Not a boolean, or a bitwise not. @@ -276,6 +273,8 @@ static constexpr std::array kIdMappingModes = { IdMappingMode::EXACT, IdMappingMode::LOOP}; +enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose }; + // Returns if function needs an f suffix on the operator when operating on a // float value i.e. sin->sinf bool needFloatSuffix(UnaryOpType t); @@ -299,6 +298,9 @@ TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ParallelType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IdMappingMode); +TORCH_CUDA_CU_API std::ostream& operator<<( + std::ostream&, + const LoadStoreOpType); std::string stringifyBooleanOp(const UnaryOpType); std::string stringifyBooleanOp(const BinaryOpType); From d2274cd1d19978350ad2103256bc689dda0b0f46 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 26 Apr 2022 12:00:01 -0700 Subject: [PATCH 28/57] code refactor bug fix --- torch/csrc/jit/codegen/cuda/mma_type.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 35e49b1dc4df..997975efa906 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -57,8 +57,8 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) { TORCH_INTERNAL_ASSERT(false, "unsupported op with ldmatrix"); break; } - return transpose ? LoadStoreOpType::LdMatrix - : LoadStoreOpType::LdMatrixTranspose; + return transpose ? 
LoadStoreOpType::LdMatrixTranspose + : LoadStoreOpType::LdMatrix; } } // namespace diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 0418aa4cf375..bbc1f4554ea3 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -303,6 +303,7 @@ bool checkLdMatrixTv(TensorView* tv) { "ldmatrix: only allow serial broadcast between ldmatrix and mma"); tv_def = tv_def->input(0)->definition(); TORCH_CHECK(tv_def != nullptr, "ldmatrix : invalid tv"); + is_immediate_output = false; } TORCH_CHECK(ir_utils::isLdMatrixOp(tv_def), "ldmatrix : invalid op type"); TORCH_CHECK( From 090752d5da91b14efa17f269ebbe51e1a7e78498 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 26 Apr 2022 12:13:11 -0700 Subject: [PATCH 29/57] update test for shared mem initialization --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 7996eb13a8df..668624b1d9d4 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -22305,7 +22305,9 @@ TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace4_CUDA) { // Record number of unary ops that modifies shared memory. if (uop->out()->isA() && uop->out()->as()->view()->getMemoryType() == - MemoryType::Shared) { + MemoryType::Shared && + // Filter out initialization expressions + uop->in()->isA()) { number_of_writes_++; } } From 640d3aa183f0e9568601ece4881f0127f48dbf2b Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 26 Apr 2022 16:45:51 -0700 Subject: [PATCH 30/57] refactor mma user interface --- .../jit/codegen/cuda/ir_interface_nodes.h | 11 ----- torch/csrc/jit/codegen/cuda/mma_type.cpp | 14 ++++++- torch/csrc/jit/codegen/cuda/mma_type.h | 40 ++++++++++++++++++- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 7 ---- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 33 +++++++-------- 5 files changed, 68 insertions(+), 37 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 45873a9fdac0..acd9e9281707 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -448,17 +448,6 @@ class TORCH_CUDA_CU_API TensorView : public Val { return is_double_buffered_; } - //! Fill in mma options in scheduling time. - //! Each mma op in Fusion IR must be configured once before lowering. - //! Mma options are configuration parameters used in lowering to mma - //! instrinsics, mainly the type of mma macro to use and input data layout - //! etc. - //! - //! TODO: This step will very likely be removed in a follow up PR. All of - //! the options configured here could actually be inferred from fusion IR - //! once we are feature complete. - void configureMma(MmaOptions options); - //! Transforms the innermost iterdomains according to the given mma swizzle, //! this should be used on the tvs that are either inputs/outputs of an //! 
MmaOp, or any tv's that are involved in prolog/epilog fusions and need to diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 997975efa906..bf9347d69453 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -1,3 +1,5 @@ +#include +#include #include namespace torch { @@ -40,6 +42,16 @@ MmaOptions MmaBuilder::build() const { return option_; } +void MmaBuilder::configureMma(TensorView* mma_output) const { + TORCH_CHECK( + mma_output->definition(), + "configureMma: invalid for input tensor ", + mma_output); + auto mma = dynamic_cast(mma_output->definition()); + TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", mma_output); + mma->configureOptions(option_); +} + namespace { // Utility to get ldmatrix direction a mma layout and operand @@ -63,7 +75,7 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) { } // namespace -LoadStoreOpType MmaBuilder::ldMatrix() { +LoadStoreOpType MmaBuilder::ldMatrix() const { return getLdMatrixType(option_); } diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index ad2ffa5018e1..30d7d2e34f23 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -95,14 +95,50 @@ struct MmaOptions { } }; -//! User interface generating mma options for mma op +//! User interface for configuring the mma and mma related +//! operators by specifying the mma instruction tile type +//! input data layout, and the operand position of a tensor. class TORCH_CUDA_CU_API MmaBuilder { public: + //! Initialized a mma builder, for the given mma instruction type. + //! TODO: the mma implementation is generic and should not have + //! strong dependency on the actual matmul tiling shapes. The + //! MatMulTileOptions provided in here is a WAR for mma format and + //! should be removed once there is support for labeling swizzles + //! on iterdomains. MmaBuilder(MmaOptions::MacroType macro, MatMulTileOptions gemm_tile); + + //! User configuration function: + //! Specifies the input matrix layout for the mma instruction. + //! see [Operand Layout Convention]. MmaBuilder& layout(MmaOptions::MmaInputLayout layout); + + //! User configuration function: + //! Specifies which element in the mma op this builder is generating + //! parameters for, i.e. A or B. This is useful when generating + //! data swizzles for different elements of mma. + //! - Operand::NotOperand means the parameters describe accumulator in mma + //! op. + //! - This option is ignored when configuring the mma operator itself. MmaBuilder& operand(MmaOptions::Operand a_or_b); + + //! Generates the matching ldmatrix instruction type for the + //! specified mma option. + LoadStoreOpType ldMatrix() const; + + //! Fill in mma options in scheduling time. + //! Each mma op in Fusion IR must be configured once before lowering. + //! Mma options are configuration parameters used in lowering to mma + //! instrinsics, mainly the type of mma macro to use and input data layout + //! etc. + //! + //! TODO: This step will very likely be removed in a follow up PR. All of + //! the options configured here could actually be inferred from fusion IR + //! once we are feature complete. + void configureMma(TensorView* mma_output) const; + + //! Export all the parameters with user's configurations applied. 
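+  //!
+  //! Typical end-to-end configuration flow, as a sketch mirroring the unit
+  //! tests (tv2 is the mma output, tv0cw a shared memory operand buffer):
+  //!
+  //!   auto mma_builder =
+  //!       MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
+  //!           .layout(MmaOptions::MmaInputLayout::TN);
+  //!   mma_builder.configureMma(tv2);
+  //!   auto tv0cr = tv0cw->cacheAfter(
+  //!       mma_builder.operand(MmaOptions::Operand::A).ldMatrix());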
MmaOptions build() const; - LoadStoreOpType ldMatrix(); private: MmaOptions option_; diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 4d709d42ac84..953cacb17103 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -1111,13 +1111,6 @@ TensorView* TensorViewBuilder::build() const { IrBuilder::create(domain, contiguity_), dtype_); } -void TensorView::configureMma(MmaOptions options) { - TORCH_CHECK(definition(), "configureMma: invalid for input tensor ", this); - auto mma = dynamic_cast(definition()); - TORCH_CHECK(mma, "configureMma: invalid for non-mma output: ", this); - mma->configureOptions(options); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 15d3d64be694..7c9fbe7787ac 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -140,9 +140,10 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { gemm_tile.cta_tile = GemmTile(16, 16, 4); gemm_tile.warp_tile = GemmTile(16, 16, 4); gemm_tile.instruction_tile = GemmTile(16, 16, 4); + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); // Write A to smem auto tv0cw = tv0b->cacheAfter(); @@ -237,7 +238,7 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0cw = tv0b->cacheAfter(); auto tv0cr = tv0cw->cacheAfter(); @@ -298,7 +299,7 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0cw = tv0b->cacheAfter(); auto tv0cr = tv0cw->cacheAfter(); @@ -370,7 +371,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -620,7 +621,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -769,7 +770,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -923,7 +924,7 @@ TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0cw = tv0b->cacheAfter(); auto tv0cr = @@ -990,7 +991,7 @@ TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { 
MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0cw = tv0b->cacheAfter(); auto tv0cr = @@ -1062,7 +1063,7 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0cw = tv0b->cacheAfter(); auto tv0cr = @@ -1142,7 +1143,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -1290,7 +1291,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -1437,7 +1438,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::NT); - tv2->configureMma(mma_builder.build()); + mma_builder.configureMma(tv2); auto tv0r = tv0->cacheAfter(); auto tv1r = tv1->cacheAfter(); @@ -1617,8 +1618,8 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) .layout(MmaOptions::MmaInputLayout::TN); - tv3->configureMma(mma_builder1.build()); - tv4->configureMma(mma_builder2.build()); + mma_builder1.configureMma(tv3); + mma_builder2.configureMma(tv4); // Global read for gemm 1 auto tv0r = tv0->cacheAfter(); @@ -1916,8 +1917,8 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) .layout(MmaOptions::MmaInputLayout::TN); - tv3->configureMma(mma_builder1.build()); - tv4->configureMma(mma_builder2.build()); + mma_builder1.configureMma(tv3); + mma_builder2.configureMma(tv4); // Global read for gemm 1 auto tv0r = inp->cacheAfter(); From da40702f1bf93a88c3406a31df09e0677071c218 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 2 May 2022 14:15:03 -0700 Subject: [PATCH 31/57] add turing mma support and test --- .../jit/codegen/cuda/lower_validation.cpp | 8 + torch/csrc/jit/codegen/cuda/mma_type.cpp | 8 +- torch/csrc/jit/codegen/cuda/mma_type.h | 1 + .../jit/codegen/cuda/runtime/tensorcore.cu | 68 ++ .../jit/codegen/cuda/scheduler/mma_utils.cpp | 5 +- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 665 ++++++++++++++++++ 6 files changed, 753 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 889597d6ff3d..190f76969225 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1021,6 +1021,14 @@ void validateMma(Fusion* fusion) { case MmaOptions::MacroType::Volta_16_16_4: validateMinimumArch(7, 0); break; + case MmaOptions::MacroType::Turing_16_8_16: + validateMinimumArch(7, 5); + + // Check that operands come from ldmatrix, can be + // relaxed once swizzles can be labeled on iterdomains. 
+ validateTuringMmaInput(mma->inA()->as()); + validateTuringMmaInput(mma->inB()->as()); + break; case MmaOptions::MacroType::Ampere_16_8_16: validateMinimumArch(8, 0); diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index bf9347d69453..9ac067eb8b77 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -18,6 +18,7 @@ MmaBuilder::MmaBuilder( case MmaOptions::MacroType::Volta_16_16_4: option_.accumulator_stride = outer_stride * 4; break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: option_.accumulator_stride = outer_stride * 2; break; @@ -58,6 +59,7 @@ namespace { LoadStoreOpType getLdMatrixType(MmaOptions options) { bool transpose = false; switch (options.macro) { + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: // Turing mma assumes TN as default transpose = (options.operand == MmaOptions::Operand::A && @@ -84,7 +86,7 @@ bool isVolta(MmaOptions::MacroType macro) { } bool isTuring(MmaOptions::MacroType macro) { - return false; + return macro == MmaOptions::MacroType::Turing_16_8_16; } bool isAmpere(MmaOptions::MacroType macro) { @@ -96,6 +98,7 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 8; break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: return 4; break; @@ -111,6 +114,7 @@ int getInputARegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: return 8; break; @@ -126,6 +130,7 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) { case MmaOptions::MacroType::Volta_16_16_4: return 4; break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: return 4; default: @@ -176,6 +181,7 @@ std::string toString(MmaOptions::MacroType mt) { case MmaOptions::MacroType::Volta_16_16_4: ss << "M16N16K4"; break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: ss << "M16N8K16"; break; diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 30d7d2e34f23..610d50c233fa 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -58,6 +58,7 @@ struct MmaOptions { NoMMA = 0, Volta_16_16_4, Ampere_16_8_16, + Turing_16_8_16, Ampere_16_8_8 // place holder for tf32 }; diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu index 01d27b343947..c1d842000837 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -214,6 +214,74 @@ DEVICE_INLINE void initM16N16K4NT(Array* accumulator) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +namespace Turing { + +namespace util { +// MMA instruction wrappers (sm_75+): +DEVICE_INLINE void m16n8k16TN( + Array* C, + Array<__half, 8, 8>* A, + Array<__half, 4, 4>* B) { + unsigned const* _A = reinterpret_cast(A); + unsigned const* _B = reinterpret_cast(B); + unsigned* _C = reinterpret_cast(C); + const unsigned* _D = reinterpret_cast(C); + + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3]) + : "r"(_A[0]), + "r"(_A[1]), + "r"(_B[0]), + "r"(_D[0]), + "r"(_D[1]), 
+ "r"(_D[2]), + "r"(_D[3])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3]) + : "r"(_A[2]), + "r"(_A[3]), + "r"(_B[1]), + "r"(_D[0]), + "r"(_D[1]), + "r"(_D[2]), + "r"(_D[3])); +} + +} // namespace util + +template +DEVICE_INLINE void initM16N8K16TN(Array* accumulator) { + float* _C = reinterpret_cast(accumulator); + _C[0] = 0; + _C[1] = 0; + _C[acc_stride] = 0; + _C[acc_stride + 1] = 0; +} + +template +DEVICE_INLINE void M16N8K16TN( + Array* C, + Array<__half, 8, 8>* A, + Array<__half, 4, 4>* B) { + // TODO: in a follow up, + // lift this fused swizzle onto iterdomain + float* _C = reinterpret_cast(C); + float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]}; + + util::m16n8k16TN(reinterpret_cast*>(&C_data[0]), A, B); + + _C[0] = C_data[0]; + _C[1] = C_data[1]; + _C[acc_stride] = C_data[2]; + _C[acc_stride + 1] = C_data[3]; +} + +} // namespace Turing + +#endif // Arch 75 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + namespace Ampere { namespace util { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index bbc1f4554ea3..19e449f93116 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -220,6 +220,7 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput( setWarpMapped(tv, 5); } break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: scheduleTuringM16N8K16MmaWarpOutput(tv, options); if (tv->definition()->isA()) { @@ -241,6 +242,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) { case MmaOptions::MacroType::Volta_16_16_4: scheduleVoltaOperandRead(tv, options); break; + case MmaOptions::MacroType::Turing_16_8_16: case MmaOptions::MacroType::Ampere_16_8_16: scheduleTuringOperandRead(tv, options); break; @@ -401,7 +403,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { : isOperandTransposed(options); // Check mma option is supported TORCH_CHECK( - options.macro == MmaOptions::MacroType::Ampere_16_8_16, + options.macro == MmaOptions::MacroType::Ampere_16_8_16 || + options.macro == MmaOptions::MacroType::Turing_16_8_16, "scheduleLdMatrix: unknown macro for ldmatrix"); if (options.operand == MmaOptions::Operand::A) { diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 7c9fbe7787ac..348f0bb1e9db 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2210,6 +2210,671 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); } +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [N,K] + auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + + // [M,N,K] -> [N,M,K] + tv0cr->reorder({{-2, -3}, {-3, -2}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({8, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [M,K] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + + mma_builder.configureMma(tv2); + + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + + // [M,K,N] -> [N,M,K] + tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [M,K,N] -> [M,N,K] + tv1cr->reorder({{-2, -1}, {-1, -2}}); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [M,K,N] -> [M,N,K] + tv2c->reorder({{-2, -1}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 16}, options); + auto t1 = 
at::randn({16, 8}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// MMA unit test on Turing +TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + + // [K,M] + auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); + // [K,N] + auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(16, 8, 16); + gemm_tile.warp_tile = GemmTile(16, 8, 16); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + mma_builder.configureMma(tv2); + + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + + auto tv2c = tv2->cacheBefore(); + + // [K,M,N] -> [N,M,K] + tv0cr->reorder({{-3, -1}, {-1, -3}}); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [K,M,N] -> [M,N,K] + tv1cr->reorder({ + {-3, -1}, + {-2, -3}, + {-1, -2}, + }); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // [K,M,N] -> [M,N,K] + tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + tv0cw->setMemoryType(MemoryType::Shared); + tv1cw->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({16, 16}, options); + auto t1 = at::randn({16, 8}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + + testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +// Matmul test on Turing +TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 511, N = 257, K = 88; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Matmul test on Turing +TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 512, N = 256, K = 128; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,K,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TT); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] -> [No,Ko,K,N] + tv1cw->reorder({{-2, -1}, {-1, -2}}); + tv1r->reorder({{-2, -1}, {-1, -2}}); + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + // [... 
Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Matmul test on Turing +TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 512, N = 256, K = 128; + + // [K,M] + auto tv0 = makeContigTensor(2, DataType::Half); + // [K,N] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [K,M,N] + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, true, false}); + + auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::NT); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + 
scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] -> [..., K,M] + tv0cw->reorder({{-2, -1}, {-1, -2}}); + tv0r->reorder({{-2, -1}, {-1, -2}}); + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] -> [No,Ko,K,N] + tv1cw->reorder({{-2, -1}, {-1, -2}}); + tv1r->reorder({{-2, -1}, {-1, -2}}); + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({K, M}, options); + auto t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit From 36786b9121ca390ba6b94d4f02237f643277a9a4 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 2 May 2022 15:47:58 -0700 Subject: [PATCH 32/57] add shared mem in zero leaf detection --- .../cuda/lower_predicate_elimination.cpp | 51 ++++++++++++++----- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 257ecc3c4080..5cd5630bc2b6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -292,6 +292,43 @@ class PredicateChcker : public IterVisitor { return false; } + // Utility to find the leaf 
iterdomains of the given + // tensor view that will be treated as "zero loops" + // in the indexing pass. + std::vector getZeroLeafIds(const TensorView* tv) const{ + TORCH_INTERNAL_ASSERT( + output->getMemoryType() == MemoryType::Local || + output->getMemoryType() == MemoryType::Shared, + "Local or shared memory tensor is assumed: ", + output->toString()); + bool is_shared_mem = tv->getMemoryType() == MemoryType::Shared; + std::vector zero_leaf_ids; + for (const auto i : c10::irange(tv->nDims())) { + auto leaf_id = tv->axis(i); + if(is_shared_mem && leaf_id->isThread()){ + // Thread parallel axes on shared mem are never + // zero loops as each thread owns its share + // of the shared mem space. + continue; + } + if ( + // Non-thread parallel dimension on the left + // of CA axes are zero loops. + i < tv->getComputeAtPosition() || + // Parallel axes on local mem is zero loop. + // FIXME: check un-mapped case. + leaf_id->isThread() || + // Mma axes, similar to vectorization, are + // implicit in hardware intrinsics, and thus + // will be treated as a zero loop. + leaf_id->isMma()) { + zero_leaf_ids.push_back(leaf_id); + } + } + + return zero_leaf_ids; + } + // An index can exceed the logical extent of the indexed domain if // it's split. It can cause a reduction op to reduce the same value // multiple times. Even a pointwise op can be a problem if the @@ -340,19 +377,7 @@ class PredicateChcker : public IterVisitor { if (split_root.empty()) { continue; } - TORCH_INTERNAL_ASSERT( - output->getMemoryType() == MemoryType::Local || - output->getMemoryType() == MemoryType::Shared, - "Local memory tensor is assumed: ", - output->toString()); - std::vector zero_leaf_ids; - for (const auto i : c10::irange(output->nDims())) { - auto leaf_id = output->axis(i); - if (i < output->getComputeAtPosition() || leaf_id->isThread() || - leaf_id->isMma()) { - zero_leaf_ids.push_back(leaf_id); - } - } + const auto zero_leaf_ids = getZeroLeafIds(output); if (zero_leaf_ids.empty()) { return true; } From 2edd1cc19a3eba1fb844a7f06d361656d1ca2597 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 2 May 2022 15:57:53 -0700 Subject: [PATCH 33/57] move shared mem predicate --- .../cuda/lower_predicate_elimination.cpp | 115 ++++++++++-------- 1 file changed, 61 insertions(+), 54 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 5cd5630bc2b6..18d95457cd04 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -62,51 +62,6 @@ class PredicateAnalyzer : public OptOutDispatch { return true; } - // This is initial step to gradually remove predicates around - // sharedmem access in suitable situations. - // Using an additional variable to track the predicate-on reasons - // when the predicate around shared mem cannot be removed. - bool needs_sharedmem_addr_pred = false; - if (producer->getMemoryType() == MemoryType::Shared || - consumer->getMemoryType() == MemoryType::Shared) { - // Indexing is based on consumer leaf ids so check the consumer. - auto& parallel_dimension_map = - GpuLower::current()->parallelDimensionMap(); - for (auto id : consumer->domain()->domain()) { - if (id->isThreadDim()) { - auto ptype = id->getParallelType(); - if (!parallel_dimension_map.isExact(ptype)) { - needs_sharedmem_addr_pred = true; - } - } - - // TODO: (Enable in a follow up) - // smem predicate removal with init would break unroll and unswitch, - // eg. 
as in issue 1133, so disabling this removal pattern for now. - if (id->getParallelType() == ParallelType::Unroll || - id->getParallelType() == ParallelType::Unswitch) { - needs_sharedmem_addr_pred = true; - } - - // TODO: (Enable in a follow up) - // This cannot yet be removed since smem initialization needs to be - // handled specially, e.g. as in smem_reduce test. Will be able to - // lift this one once the generic pred removal pass with fusion - // traversal is ready. - auto consumer_def = consumer->definition(); - if (consumer_def->isA() || - consumer_def->isA()) { - if (producer->getMemoryType() == MemoryType::Shared) { - needs_sharedmem_addr_pred = true; - } - } - } - } - - if (needs_sharedmem_addr_pred) { - return true; - } - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); auto c2p = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) @@ -218,7 +173,7 @@ class PredicateChcker : public IterVisitor { void handle(Expr* expr) final { needs_predicate_ = predicateIntDiv(expr) || predicateMisalignedVectorize(expr) || predicateShift(expr) || - predicateProducerConsumerPair(expr) || + predicateSharedMemAccess(expr) || predicateProducerConsumerPair(expr) || predicateNonDivisibleRootDomains(expr) || predicateNonDivisibleSplit(expr); @@ -292,20 +247,72 @@ class PredicateChcker : public IterVisitor { return false; } + bool predicateSharedMemAccess(Expr* expr) const { + // This is initial step to gradually remove predicates around + // sharedmem access in suitable situations. + // Using an additional variable to track the predicate-on reasons + // when the predicate around shared mem cannot be removed. + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + for (auto producer : ir_utils::filterByType(expr->inputs())) { + if (producer->getMemoryType() == MemoryType::Shared || + consumer->getMemoryType() == MemoryType::Shared) { + if (needSharedMemPredicate(producer, consumer)) { + return true; + } + } + } + } + return false; + } + + bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) + const { + // Indexing is based on consumer leaf ids so check the consumer. + auto& parallel_dimension_map = GpuLower::current()->parallelDimensionMap(); + for (auto id : consumer->domain()->domain()) { + if (id->isThreadDim()) { + auto ptype = id->getParallelType(); + if (!parallel_dimension_map.isExact(ptype)) { + return true; + } + } + + // TODO: (Enable in a follow up) + // smem predicate removal with init would break unroll and unswitch, + // eg. as in issue 1133, so disabling this removal pattern for now. + if (id->getParallelType() == ParallelType::Unroll || + id->getParallelType() == ParallelType::Unswitch) { + return true; + } + + // TODO: (Enable in a follow up) + // This cannot yet be removed since smem initialization needs to be + // handled specially, e.g. as in smem_reduce test. Will be able to + // lift this one once the generic pred removal pass with fusion + // traversal is ready. + auto consumer_def = consumer->definition(); + if (consumer_def->isA() || consumer_def->isA()) { + if (producer->getMemoryType() == MemoryType::Shared) { + return true; + } + } + } + } + // Utility to find the leaf iterdomains of the given // tensor view that will be treated as "zero loops" // in the indexing pass. 
- std::vector getZeroLeafIds(const TensorView* tv) const{ + std::vector getZeroLeafIds(const TensorView* tv) const { TORCH_INTERNAL_ASSERT( - output->getMemoryType() == MemoryType::Local || - output->getMemoryType() == MemoryType::Shared, - "Local or shared memory tensor is assumed: ", - output->toString()); + output->getMemoryType() == MemoryType::Local || + output->getMemoryType() == MemoryType::Shared, + "Local or shared memory tensor is assumed: ", + output->toString()); bool is_shared_mem = tv->getMemoryType() == MemoryType::Shared; std::vector zero_leaf_ids; for (const auto i : c10::irange(tv->nDims())) { auto leaf_id = tv->axis(i); - if(is_shared_mem && leaf_id->isThread()){ + if (is_shared_mem && leaf_id->isThread()) { // Thread parallel axes on shared mem are never // zero loops as each thread owns its share // of the shared mem space. @@ -314,11 +321,11 @@ class PredicateChcker : public IterVisitor { if ( // Non-thread parallel dimension on the left // of CA axes are zero loops. - i < tv->getComputeAtPosition() || + i < tv->getComputeAtPosition() || // Parallel axes on local mem is zero loop. // FIXME: check un-mapped case. leaf_id->isThread() || - // Mma axes, similar to vectorization, are + // Mma axes, similar to vectorization, are // implicit in hardware intrinsics, and thus // will be treated as a zero loop. leaf_id->isMma()) { From bcb18af067c4ecb084aef819bd1855e0d701bf47 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 2 May 2022 16:00:32 -0700 Subject: [PATCH 34/57] comment --- torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 18d95457cd04..cacd5d16123c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -265,6 +265,8 @@ class PredicateChcker : public IterVisitor { return false; } + // Check for conditions where the predicate cannot be removed + // when either producer or consumer is in shared memory. bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) const { // Indexing is based on consumer leaf ids so check the consumer. @@ -272,6 +274,8 @@ class PredicateChcker : public IterVisitor { for (auto id : consumer->domain()->domain()) { if (id->isThreadDim()) { auto ptype = id->getParallelType(); + // Need to predicate to avoid out of bound access + // because of over-subscribed block size. if (!parallel_dimension_map.isExact(ptype)) { return true; } From fd5e178ed26e9e8e6a130cbf5de291212b65fa2b Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 3 May 2022 10:05:46 -0700 Subject: [PATCH 35/57] adjust unused ldmatrix address for Turing --- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 060f2920b0e3..90ffb27a63eb 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -21,6 +21,37 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { return smem_ptr_uint; } +// LdMatrix has .x1, .x2 and .x4 options, currently we actively use .x2 and +// .x4. In .x2 option. the the address register of upper half warp (lane 16-31) +// are un-used but on Turing [sm75,sm80) architecture these un-used address +// need to be valid, in the sense that: +// 1. 
The data it points to has to be within allocated shared mem buffer. +// 2. The address needs to be aligned to 16 byte. +// This function addresses 2. above by masking out the sub-16B component +// of the address in upper warp and 1. is guaranteed by ldmatrix swizzle +// util. +// This will not affect any functionality since the adjusted address +// are not used. This is just modification to satisfy the address +// requirement on Turing. The requirement is lifted in sm80+, so +// this function is a no-op on Ampere or above. +DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) { +#if (__CUDA_ARCH__ < 800) + unsigned thread_id = threadIdx.x; + // Upper half warp has 8 bytes offset from aligned in .x2 option + // of ldmatrix. Currently no support for .x1 so assume always + // adjust by half warp. + const unsigned half_warp = 16; + // Need to adjust to 16 byte alignment, mask out un-aligned component. + const unsigned mask_out = 16 - 1; + // Adjust only in upper half warp. + // use bit math to reduce strength + if (thread_id & half_warp) { + // mask out the bits where adjust_mask has 1. + addr_in_byte &= (~mask_out); + } +#endif //(__CUDA_ARCH__ < 800) +} + } // namespace util // Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. @@ -36,6 +67,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" : "=r"(val.x), "=r"(val.y) : "r"(addr)); @@ -47,6 +79,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) { DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" : "=r"(val.x), "=r"(val.y) : "r"(addr)); From 86f2d612b5fca0cf910bc2d6775e1faef8ca85e2 Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 3 May 2022 10:06:25 -0700 Subject: [PATCH 36/57] cleanup --- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 90ffb27a63eb..f511501ff11c 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -36,7 +36,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { // this function is a no-op on Ampere or above. DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) { #if (__CUDA_ARCH__ < 800) - unsigned thread_id = threadIdx.x; + const unsigned thread_id = threadIdx.x; // Upper half warp has 8 bytes offset from aligned in .x2 option // of ldmatrix. Currently no support for .x1 so assume always // adjust by half warp. 
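The address adjustment in the two hunks above reduces to plain integer masking: only lanes 16-31 are touched, and their addresses are rounded down to a 16-byte boundary. A minimal host-side sketch of the same arithmetic, assuming the lane id is passed in explicitly rather than read from threadIdx.x and using an illustrative function name rather than the device utility itself:

#include <cassert>

// Clear the sub-16B component of the address for the upper half warp
// (lanes 16-31): the unused .x2 ldmatrix address registers still have to
// be 16-byte aligned on Turing.
unsigned adjustForUpperHalfWarp(unsigned addr_in_byte, unsigned lane_id) {
  const unsigned half_warp = 16;
  const unsigned mask_out = 16 - 1; // low 4 bits = offset within a 16B chunk
  if (lane_id & half_warp) {        // true only for lanes 16-31
    addr_in_byte &= ~mask_out;      // round down to a 16-byte boundary
  }
  return addr_in_byte;
}

int main() {
  assert(adjustForUpperHalfWarp(0x1238, 3) == 0x1238);  // lower half warp: unchanged
  assert(adjustForUpperHalfWarp(0x1238, 19) == 0x1230); // upper half warp: aligned down
  return 0;
}

As the comments in the patch note, the alignment requirement on the unused addresses is lifted on sm80+, so the device-side version compiles to a no-op there.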
From 484efd8eae9165db15f8066ea6757f53c769e0aa Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 3 May 2022 11:15:55 -0700 Subject: [PATCH 37/57] update comment --- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index f511501ff11c..7740802c9330 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -23,17 +23,20 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) { // LdMatrix has .x1, .x2 and .x4 options, currently we actively use .x2 and // .x4. In .x2 option. the the address register of upper half warp (lane 16-31) -// are un-used but on Turing [sm75,sm80) architecture these un-used address +// are un-used but on Turing [sm75,sm80) architecture these un-used addresses // need to be valid, in the sense that: // 1. The data it points to has to be within allocated shared mem buffer. // 2. The address needs to be aligned to 16 byte. +// See also: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix // This function addresses 2. above by masking out the sub-16B component // of the address in upper warp and 1. is guaranteed by ldmatrix swizzle // util. -// This will not affect any functionality since the adjusted address -// are not used. This is just modification to satisfy the address -// requirement on Turing. The requirement is lifted in sm80+, so -// this function is a no-op on Ampere or above. +// This will **not** affect any functionality. This is just modification +// of unused pointers to satisfy the alignment requirement on Turing +// hardware. +// The alignment requirement is lifted on sm80+, +// so this function is a no-op on Ampere or above. DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) { #if (__CUDA_ARCH__ < 800) const unsigned thread_id = threadIdx.x; From 9e298b06bb771d8e157b6966a6e84e73f0387c0b Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 15:09:51 -0700 Subject: [PATCH 38/57] cleanup;undo lower_unroll change; --- .../cuda/lower_predicate_elimination.cpp | 10 +++-- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 43 +------------------ torch/csrc/jit/codegen/cuda/mutator.cpp | 2 +- 3 files changed, 8 insertions(+), 47 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index c2928619f785..39f66056022c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -275,7 +275,7 @@ class PredicateChcker : public IterVisitor { if (id->isThreadDim()) { auto ptype = id->getParallelType(); // Need to predicate to avoid out of bound access - // because of over-subscribed block size. + // because of over-subscribed block size. if (!parallel_dimension_map.isExact(ptype)) { return true; } @@ -301,6 +301,8 @@ class PredicateChcker : public IterVisitor { } } } + + return false; } // Utility to find the leaf iterdomains of the given @@ -308,10 +310,10 @@ class PredicateChcker : public IterVisitor { // in the indexing pass. 
std::vector getZeroLeafIds(const TensorView* tv) const { TORCH_INTERNAL_ASSERT( - output->getMemoryType() == MemoryType::Local || - output->getMemoryType() == MemoryType::Shared, + tv->getMemoryType() == MemoryType::Local || + tv->getMemoryType() == MemoryType::Shared, "Local or shared memory tensor is assumed: ", - output->toString()); + tv->toString()); bool is_shared_mem = tv->getMemoryType() == MemoryType::Shared; std::vector zero_leaf_ids; for (const auto i : c10::irange(tv->nDims())) { diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 2e45d5e9adc9..434d1711d9c8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -109,48 +109,7 @@ void UnrollPass::handle(Expr* expr) { // Vectorized expressions should never use inline predicates kir::Predicate* pred = nullptr; - - bool is_ld_matrix = false; - bool is_ld_matrix_producer = false; - if (ir_utils::isLdMatrixOp(expr)) { - // Note & TODO: currently cannot support predicated ldmatrix since it'd - // need to make sure all threads in a warp evaluate the predicate - // to the same value. - // For now asserting that the predicate for ldmatrix and its producer - // can be omitted. Would need to build out support for warp level - // predicates if we want vectorization predicates to be active on - // these ops. - TORCH_INTERNAL_ASSERT( - GpuLower::current()->predicateElimination().canOmitPredicate(expr), - "Vectorize: unsupported predicate for warp ops"); - is_ld_matrix = true; - } - - // Note & TODO: This part is a WAR that disables any predicated - // producer of ldmatrix consumers. Ldmatrix currently cannot - // be predicated due to no support in warp-uniform predicates. - // A predicated producer will result in ldmatrix propagating - // data outside of the producer's valid region. - // Except for vectorized global loads which should be properly - // initialized outside of the predicated region. 
- for (auto out_tv : ir_utils::filterByType(expr->outputs())) { - for (auto use : expr->fusion()->unordered_uses(out_tv)) { - if (ir_utils::isLdMatrixOp(use)) { - TORCH_INTERNAL_ASSERT( - GpuLower::current()->predicateElimination().canOmitPredicate( - expr), - "Vectorize: producers of ldmatrix cannot be predicated except for global loads:", - expr); - is_ld_matrix_producer = true; - break; - } - } - } - - const bool omit_vector_predicate_for_ldmatrix = - is_ld_matrix || is_ld_matrix_producer; - - if (!unswitched_loop_ && !omit_vector_predicate_for_ldmatrix && + if (!unswitched_loop_ && std::any_of( for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) { return fl->iter_domain()->getParallelType() == diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index c22388d32b5a..6d59d43dee0d 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -301,7 +301,7 @@ void OptOutMutator::mutate(LoadStoreOp* ldst) { auto container = ldst->container(); container->removeExpr(ldst); - auto new_ldst = IrBuilder::create(container, op_type, out, in); + IrBuilder::create(container, op_type, out, in); } void OptOutMutator::mutate(BroadcastOp* bop) { From 96151967242fc7f5bd7e9810569c05306c1c2374 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 15:15:03 -0700 Subject: [PATCH 39/57] rebase fix --- torch/csrc/jit/codegen/cuda/lower_index.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 3abea2536bf9..6442e4cd166d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -611,6 +611,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { const auto in = lowerSrcIndex(ldst->in(), ldst->out()); const auto out = lowerDstIndex(ldst->out()); pushBack(IrBuilder::create(ldst->opType(), out, in)); + GpuLower::current()->propagateExprInfo(ldst, back()); } void IndexLowering::handle(const MmaOp* mma) { From a787b59a73cdd68f41548648207a0935fd0a62a8 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 15:58:09 -0700 Subject: [PATCH 40/57] fix shared mem init --- .../cuda/lower_predicate_elimination.cpp | 66 ++++++++++++------- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 11 ++++ torch/csrc/jit/codegen/cuda/lower_utils.h | 4 ++ 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 39f66056022c..0c0d5d7109a2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -23,12 +23,10 @@ namespace { // predicating these ops will require extra steps to ensure that // the whole warp will get the same value. void assertOnWarpOps(const Expr* expr) { - if (auto uop = dynamic_cast(expr)) { - TORCH_INTERNAL_ASSERT( - !ir_utils::isLdMatrixOp(expr), - "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims", - expr); - } + TORCH_INTERNAL_ASSERT( + !ir_utils::isLdMatrixOp(expr), + "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims", + expr); TORCH_INTERNAL_ASSERT( !expr->isA(), "Mma op: cannot eliminate predicate for mma op, tiling not valid. 
", @@ -39,6 +37,26 @@ void assertOnWarpOps(const Expr* expr) { namespace { +// Utility to check if the scheduled domain of the given +// TensorView represent an exact shared mem access, meaning +// that all the thread parallel dimensions on the leaf nodes +// are exact so that the shared mem read/write would not +// run out of bound because of thread over-subscription. +bool isExactParallelSharedMemAccess(TensorView* tv) { + auto& parallel_dimension_map = GpuLower::current()->parallelDimensionMap(); + for (auto id : tv->domain()->domain()) { + if (id->isThreadDim()) { + auto ptype = id->getParallelType(); + // Need to predicate to avoid out of bound access + // because of over-subscribed block size. + if (!parallel_dimension_map.isExact(ptype)) { + return false; + } + } + } + return true; +} + class PredicateAnalyzer : public OptOutDispatch { public: //! Checks if a predicate is needed to avoid out-of-bound accesses. @@ -270,17 +288,15 @@ class PredicateChcker : public IterVisitor { bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) const { // Indexing is based on consumer leaf ids so check the consumer. - auto& parallel_dimension_map = GpuLower::current()->parallelDimensionMap(); - for (auto id : consumer->domain()->domain()) { - if (id->isThreadDim()) { - auto ptype = id->getParallelType(); - // Need to predicate to avoid out of bound access - // because of over-subscribed block size. - if (!parallel_dimension_map.isExact(ptype)) { - return true; - } - } + // If consumer schedule contains in-exact thread parallel + // dimensions, need to predicate against out of bound + // shared memory access by out of bound threads. + if (!isExactParallelSharedMemAccess(consumer)) { + return true; + } + + for (auto id : consumer->domain()->domain()) { // TODO: (Enable in a follow up) // smem predicate removal with init would break unroll and unswitch, // eg. as in issue 1133, so disabling this removal pattern for now. @@ -750,14 +766,20 @@ bool PredicateElimination::canOmitPredicate(const Expr* expr) const { TORCH_INTERNAL_ASSERT(expr != nullptr); const auto out_tv = ir_utils::getTvOutput(expr); TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression"); - // No need to predicate local tensors to which a scalar is assigned - if (out_tv->getMemoryType() == MemoryType::Local) { - if (auto uop = dynamic_cast(expr)) { - if (uop->getUnaryOpType() == UnaryOpType::Set && uop->in()->isScalar()) { - return true; - } + + if (ir_utils::isTensorScalarFillOp(expr)) { + if (out_tv->getMemoryType() == MemoryType::Local) { + // Filling a local tensor with scalar shouldn't + // need any predicate currently. + return true; + } else if (out_tv->getMemoryType() == MemoryType::Shared) { + // A shared memory initialization should be same except + // that we'd need a predicate to guard against out of + // bound access by out of inexact threads. 
+ return isExactParallelSharedMemAccess(out_tv); } } + if (non_predicated_exprs_.find(expr) != non_predicated_exprs_.end()) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 06cc625a855b..943ec371e77a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -118,6 +118,17 @@ bool isLdMatrixOp(const Expr* expr) { return false; } +bool isTensorScalarFillOp(const Expr* expr) { + if (auto uop = dynamic_cast(expr)) { + if (uop->getUnaryOpType() == UnaryOpType::Set) { + if (uop->in()->isScalar()) { + return true; + } + } + } + return false; +} + TensorView* getTv(Val* val) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return const_cast(getTv(const_cast(val))); diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 47373d12f4c2..58352c912b37 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -115,6 +115,10 @@ kir::Allocate* allocGlobalBufferForGridComm( //! a ldmatrix intrinsic. bool isLdMatrixOp(const Expr* expr); +//! Returns true if the given expression fills the output +//! tensor with a single scalar. +bool isTensorScalarFillOp(const Expr* expr); + } // namespace ir_utils namespace loop_utils { From 7cf98abbe39a537c4b9b9f06831933f6e003be40 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 16:08:44 -0700 Subject: [PATCH 41/57] clean up --- .../jit/codegen/cuda/lower_predicate_elimination.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 0c0d5d7109a2..b4c8d69f0018 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -311,7 +311,7 @@ class PredicateChcker : public IterVisitor { // lift this one once the generic pred removal pass with fusion // traversal is ready. auto consumer_def = consumer->definition(); - if (consumer_def->isA() || consumer_def->isA()) { + if (ir_utils::isReductionOp(consumer_def)) { if (producer->getMemoryType() == MemoryType::Shared) { return true; } @@ -625,7 +625,12 @@ class PredicateChcker : public IterVisitor { // predicate. In fact this is the only way we can // use mma at the moment since we could not predicate // mma ops without guaranteeing warp uniform results. - if (mma->fusion()->unordered_uses(input).size() > 1) { + auto input_init = + GpuLower::current()->predicateElimination().getInitValue(input); + + // TODO: + // clean up this to support more generic prolog fusion. + if (input_init != nullptr && !input_init->sameAs(mma->init())) { // This is a WAR at the moment. We would need to propagate // initialization information from PredicateElimination // pass to most accurately detect if the input is @@ -633,8 +638,6 @@ class PredicateChcker : public IterVisitor { // This could also be fixed when we have the traversal // based predicate elimination and initialization pass // ready. Would be easy to clean up this part at that point. - // TODO: - // clean up this to support more generic prolog fusion. 
needs_predicate_ = true; return; } From e2e07070c8c34479586aa67b694f8a180b990d35 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 16:23:19 -0700 Subject: [PATCH 42/57] comment --- .../jit/codegen/cuda/lower_predicate_elimination.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index b4c8d69f0018..efd1669c37fe 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -66,15 +66,13 @@ class PredicateAnalyzer : public OptOutDispatch { //! local memory. However, accessing producer tensors still may //! result in out-of-bound as they are replayed as consumers. static bool needsPredicate(TensorView* producer, TensorView* consumer) { - // Both tensors must be on local memory. Global tensors must be + // Both tensors must be on local or shared memory. Global tensors must be // predicated as allocation is done based on root domains. Smem - // and local tensors are allocated based on leaf domains, however, - // smem tensors are parallelized, which is highly likely, the size + // and local tensors are allocated based on leaf domains. + // However, smem tensors are parallelized, which is highly likely, the size // of the parallelized axis is the actual size of the axis, not - // the number of threads. Since the number of threads can be - // larger than the axis size, it's not safe to skip predication - - // Check that parallel dimension will not generate out of bound index + // the number of threads. This is currently actively checked to avoid + // out of bound shared mem access by out of bound threads. if (producer->getMemoryType() == MemoryType::Global || consumer->getMemoryType() == MemoryType::Global) { return true; From 9c224dcc5df0eea3ca59604c49379828b916d476 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 16:44:14 -0700 Subject: [PATCH 43/57] zero leaf update --- torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index efd1669c37fe..b84089b8938c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -332,7 +332,7 @@ class PredicateChcker : public IterVisitor { std::vector zero_leaf_ids; for (const auto i : c10::irange(tv->nDims())) { auto leaf_id = tv->axis(i); - if (is_shared_mem && leaf_id->isThread()) { + if (is_shared_mem && leaf_id->isThreadDim()) { // Thread parallel axes on shared mem are never // zero loops as each thread owns its share // of the shared mem space. @@ -343,7 +343,7 @@ class PredicateChcker : public IterVisitor { // of CA axes are zero loops. i < tv->getComputeAtPosition() || // Parallel axes on local mem is zero loop. - // FIXME: check un-mapped case. + // Grid axes on shared mem is zero loop. 
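            // (A shared buffer is per-block, so BIDx/y/z axes never index
            // into the allocation and likewise materialize no loop.)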
leaf_id->isThread() || // Mma axes, similar to vectorization, are // implicit in hardware intrinsics, and thus From 584567cbca8d06cabf1691276a6959337bf1267e Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 16:52:43 -0700 Subject: [PATCH 44/57] comment --- torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index b84089b8938c..09bd891ccfdb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -68,7 +68,7 @@ class PredicateAnalyzer : public OptOutDispatch { static bool needsPredicate(TensorView* producer, TensorView* consumer) { // Both tensors must be on local or shared memory. Global tensors must be // predicated as allocation is done based on root domains. Smem - // and local tensors are allocated based on leaf domains. + // and local tensors are allocated based on leaf domains. // However, smem tensors are parallelized, which is highly likely, the size // of the parallelized axis is the actual size of the axis, not // the number of threads. This is currently actively checked to avoid @@ -628,6 +628,8 @@ class PredicateChcker : public IterVisitor { // TODO: // clean up this to support more generic prolog fusion. + // Will need additional analysis passes on initialization + // propagation and further predicate placement on top. if (input_init != nullptr && !input_init->sameAs(mma->init())) { // This is a WAR at the moment. We would need to propagate // initialization information from PredicateElimination From 09f5ce0fd61834a3c6cbeba3185139836a80c512 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 4 May 2022 17:01:02 -0700 Subject: [PATCH 45/57] use toString --- .../cuda/lower_predicate_elimination.cpp | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 09bd891ccfdb..276192fbecb7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -26,7 +26,7 @@ void assertOnWarpOps(const Expr* expr) { TORCH_INTERNAL_ASSERT( !ir_utils::isLdMatrixOp(expr), "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims", - expr); + expr->toString()); TORCH_INTERNAL_ASSERT( !expr->isA(), "Mma op: cannot eliminate predicate for mma op, tiling not valid. ", @@ -456,7 +456,7 @@ class PredicateChcker : public IterVisitor { // predication for expressions involving global memory, this // should never occur. TORCH_INTERNAL_ASSERT( - input_def != nullptr, "Inconsistent input found: ", input); + input_def != nullptr, "Inconsistent input found: ", input->toString()); // The input needs to be initialized to the init value to omit // the predicate, so if the input has its own init value, i.e., @@ -516,7 +516,9 @@ class PredicateChcker : public IterVisitor { // predication for expressions involving global memory, this // should never occur. 
TORCH_INTERNAL_ASSERT( - input_def != nullptr, "Inconsistent input found: ", input); + input_def != nullptr, + "Inconsistent input found: ", + input->toString()); // The input needs to be initialized to the init value to omit // the predicate, so if the input has its own init value, i.e., @@ -548,7 +550,9 @@ class PredicateChcker : public IterVisitor { // predication for expressions involving global memory, this // should never occur. TORCH_INTERNAL_ASSERT( - input_def != nullptr, "Inconsistent input found: ", input); + input_def != nullptr, + "Inconsistent input found: ", + input->toString()); // The input needs to be initialized to the init value to omit // the predicate, so if the input has its own init value, i.e., @@ -608,7 +612,9 @@ class PredicateChcker : public IterVisitor { for (auto input : ir_utils::filterByType(mma->inputs())) { auto input_def = input->definition(); TORCH_INTERNAL_ASSERT( - input_def != nullptr, "Inconsistent input found: ", input); + input_def != nullptr, + "Inconsistent input found: ", + input->toString()); Val* input_init = ir_utils::getReductionInitValOf(input); if (input_init != nullptr && !mma->init()->sameAs(input_init)) { @@ -751,9 +757,9 @@ bool PredicateElimination::setReductionInitValue( "Incosistent setting of initialization value for t", tv->name(), ". Prev: ", - existing_val, + existing_val->toString(), ", New: ", - reduction_init); + reduction_init->toString()); return false; } } From 998adb575e53b8c33e66ec453336335062f3d338 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 6 May 2022 10:09:56 -0700 Subject: [PATCH 46/57] comment --- .../codegen/cuda/lower_predicate_elimination.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 276192fbecb7..094062feff10 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -322,6 +322,8 @@ class PredicateChcker : public IterVisitor { // Utility to find the leaf iterdomains of the given // tensor view that will be treated as "zero loops" // in the indexing pass. + // For details on zero loops, see indexMapFromTV in + // lower index pass. std::vector getZeroLeafIds(const TensorView* tv) const { TORCH_INTERNAL_ASSERT( tv->getMemoryType() == MemoryType::Local || @@ -636,6 +638,17 @@ class PredicateChcker : public IterVisitor { // clean up this to support more generic prolog fusion. // Will need additional analysis passes on initialization // propagation and further predicate placement on top. + // More TODO: + // Even when producer is initialized, it is still generally + // not safe to remove predicate around reduction ops if the + // producer is not predicated. + // On the other side, we do have patterns like ldmatrix->mma where + // both producer and consumer cannot be safely predicated without + // guaranteeing warp uniform results. + // This is currently a WAR and relies on validation pass to exclude + // complex prolog patterns in mma based matmul kernels. Will + // definitely need to revisit and build out predicate and + // initialization analysis pass to better handle this case. if (input_init != nullptr && !input_init->sameAs(mma->init())) { // This is a WAR at the moment. 
We would need to propagate // initialization information from PredicateElimination From 13266ced477d21273af1237c98c6a22c024e3bcb Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 6 May 2022 10:19:27 -0700 Subject: [PATCH 47/57] expand tensor filling op detect to include load store --- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 943ec371e77a..c82ce62edeca 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -119,13 +119,25 @@ bool isLdMatrixOp(const Expr* expr) { } bool isTensorScalarFillOp(const Expr* expr) { - if (auto uop = dynamic_cast(expr)) { - if (uop->getUnaryOpType() == UnaryOpType::Set) { - if (uop->in()->isScalar()) { - return true; - } + // Check that the input is a single scalar. + if (expr->inputs().size() == 1 && expr->input(0)->isScalar()) { + // All load store op with a single scalar input + // should be a scalar filling op. Semantically + // it literally means `Store`'ing a scalar + // into a tensor. + if (expr->isA()) { + return true; + } + // Unary copy op is also a scalar filling op. + if (auto uop = dynamic_cast(expr)) { + return uop->getUnaryOpType() == UnaryOpType::Set; } } + // Ideally any scalar expression that outputs + // to a tensor should be considered in this function + // but since we currently only limit scope to + // initialization patterns so other scalar expr's + // are low priority and are excluded here to avoid confusion. return false; } From fc35480b231a8cbc9294d5f5212efcf211501b03 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 6 May 2022 15:36:52 -0700 Subject: [PATCH 48/57] arch guard update --- torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu index 01d27b343947..7d0c6be7c2be 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu @@ -212,7 +212,7 @@ DEVICE_INLINE void initM16N16K4NT(Array* accumulator) { } // namespace Volta -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) namespace Ampere { @@ -272,6 +272,6 @@ DEVICE_INLINE void M16N8K16TN( } // namespace Ampere -#endif // Arch 75 +#endif // Arch 80 #undef DEVICE_INLINE From 76563a5754bd457304efcbc8725545f9e3b8fe15 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 6 May 2022 16:03:57 -0700 Subject: [PATCH 49/57] fix rebase --- torch/csrc/jit/codegen/cuda/lower_predicate.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index cda210989f17..bfac3f165008 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -56,7 +56,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { "Expecting predicated body to only have one vectorized expression."); auto vec_expr = ite->thenBody()[0]; TORCH_INTERNAL_ASSERT( - vec_expr->isA(), + vec_expr->isA() || vec_expr->isA(), "Vectorize predicate exprs only supported on set operations."); TORCH_INTERNAL_ASSERT( ir_utils::isTvOp(vec_expr), From 119f0eb3dc22de3672779b8ebe79248683559578 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 6 May 2022 
20:32:45 -0700 Subject: [PATCH 50/57] WAR for buffer re-use --- .../cuda/lower_predicate_elimination.cpp | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 8e011c4cd907..3418f4442f7d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -278,6 +278,7 @@ class PredicateChcker : public IterVisitor { } } } + return false; } @@ -294,6 +295,47 @@ class PredicateChcker : public IterVisitor { return true; } + // TODO: This is directed WAR on FusionPersistentNormLocalShared. + // This use case along with other previous issues motivate a + // joint optimization of predicate removal and buffer reuse. + // In this particular case: + // __shared__ T0 [10], T1[10] + // for i in ... + // T1[i] = T0[i] + ... // exp0 + // T2 = 0; // init for exp1 + // if(pred) + // T2 = T1 ... // exp1 + // If we remove pred around expr1, as the way the pred removal + // pass is set up, the init for expr will be pushed up to + // initialize T1 instead. + // However if we initialize T1, the code will look like: + // for i in ... + // T1[i] = 0; + // for i in ... + // if(pred) + // T1[i] = T0[i] + ... + // Note that we'd be able to reuse buffer of T0 for T1 but + // if we initialze T1 we cannot do that and thus the + // kernel would not fit in smaller devices. + if (producer->getMemoryType() == MemoryType::Shared) { + if (auto producer_def = producer->definition()) { + if (std::any_of( + producer_def->inputs().begin(), + producer_def->inputs().end(), + [](Val* val) { + if (auto tv = ir_utils::getTv(val)) { + return tv->getMemoryType() == MemoryType::Shared; + } + return false; + })) { + // Disable shared memory producers that is a consumer + // of another shared memory tensor. The initialization would + // break potential opportunity to re-use shared mem buffer. 
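          // In the sketch above, keeping exp1 predicated means T1 never
          // gets an unconditional init loop of its own, which keeps the
          // option of re-using T0's shared memory for T1 open.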
+ return true; + } + } + } + for (auto id : consumer->domain()->domain()) { // TODO: (Enable in a follow up) // smem predicate removal with init would break unroll and unswitch, From abcd1e85f0d81bb21171fc34c0d05dacf0ecf25d Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 9 May 2022 13:49:29 -0700 Subject: [PATCH 51/57] clean up --- .../jit/codegen/cuda/scheduler/mma_utils.cpp | 91 ++++++++++++------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 1b467d564f0b..258688f399dc 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -260,27 +260,37 @@ namespace { // Utility to check operand innermost scheduled dimensions void validateInnerMNK(TensorView* tv, MmaOptions options, int m, int n, int k) { TORCH_INTERNAL_ASSERT(tv->nDims() >= 3); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::M), - tv->axis(-3), - m)); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::N), - tv->axis(-2), - n)); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - k)); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::M), + tv->axis(-3), + m), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::N), + tv->axis(-2), + n), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + k), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); } void validateResultInnerMN(TensorView* tv, int m, int n) { TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); int root_dim = tv->getMaybeRFactorDomain().size(); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - tv->getMaybeRFactorDomain()[root_dim - 2], tv->axis(-2), m)); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), n)); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + tv->getMaybeRFactorDomain()[root_dim - 2], tv->axis(-2), m), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), n), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); } //! Performs checks on tv given to schedule ld matrix. @@ -289,7 +299,12 @@ void validateResultInnerMN(TensorView* tv, int m, int n) { //! 2. direct output of a broadcast op following a ldmatrix op //! Returns true if the tv is an immediate output of ldmatrix op //! -//! TODO: this check is +//! TODO: this check is a WAR with pattern matching for now. +//! The two patterns mentioned above are the only supported use +//! cases of ldmatrix currently. This restriction can be greatly +//! relaxed after the iterdomain swizzle infrastructure, which +//! will provide the capability to directly model the exact +//! data format of ldmatrix output. 
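//!
//! Roughly, the accepted shapes are (tensor names illustrative only):
//!   tv_ld = ldmatrix read of a shared memory operand   // case 1
//!   tv_b  = broadcast(tv_ld, ...)                      // case 2
//! with any other producer pattern rejected by this check.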
bool checkLdMatrixTv(TensorView* tv) { // First check if tv is an ldmatrix output: auto tv_def = tv->definition(); @@ -406,14 +421,18 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { if (options.operand == MmaOptions::Operand::A) { TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); // validation: - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::M), - tv->axis(-2), - 16)); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - 16)); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::M), + tv->axis(-2), + 16), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + 16), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); //[16m, 16k] tv->split(-2, 8); @@ -437,14 +456,18 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { tv->axis(-2)->parallelize(ParallelType::TIDx); } else if (options.operand == MmaOptions::Operand::B) { // validation: - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::N), - tv->axis(-2), - 8)); - TORCH_INTERNAL_ASSERT(canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - 16)); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::N), + tv->axis(-2), + 8), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + TORCH_INTERNAL_ASSERT( + canValidateIsInnerDim( + getMmaOperandRootDimension(tv, options, MmaDimension::K), + tv->axis(-1), + 16), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); if (transposed) { // [8, 16] @@ -461,7 +484,7 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { tv->split(-1, 4); tv->split(-2, 2); - // 0 1 2 3 4 + // 0 1 2 3 //[8, oo2,oi2,i4] tv->reorder({{-4, -2}, {-2, -4}}); From ebc215ad8d7fd9536329b9790526c4342d1001c5 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 9 May 2022 13:56:12 -0700 Subject: [PATCH 52/57] update comment --- torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 3418f4442f7d..fbb2e6cd6529 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -301,7 +301,8 @@ class PredicateChcker : public IterVisitor { // In this particular case: // __shared__ T0 [10], T1[10] // for i in ... - // T1[i] = T0[i] + ... // exp0 + // if(pred) + // T1[i] = T0[i] + ... // exp0 // T2 = 0; // init for exp1 // if(pred) // T2 = T1 ... 
// exp1 From f0eb7b60181ec91b4e3a6faf216e24f5b10c9ee2 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 23 May 2022 17:03:53 -0700 Subject: [PATCH 53/57] use relaxed arch guard --- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 918791518128..3d012e7d33ea 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2234,8 +2234,6 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // MMA unit test on Turing TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); @@ -2293,7 +2291,9 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { auto t1 = at::randn({8, 16}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -2303,8 +2303,6 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { // MMA unit test on Turing TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); @@ -2366,7 +2364,9 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { auto t1 = at::randn({16, 8}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -2376,8 +2376,6 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { // MMA unit test on Turing TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); @@ -2442,7 +2440,9 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { auto t1 = at::randn({16, 8}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -2452,8 +2452,6 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { // Matmul test on Turing TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); int M = 511, N = 257, K = 88; @@ -2590,7 +2588,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); @@ -2600,8 +2600,6 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { // Matmul test on Turing TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); int M = 512, N = 256, K = 128; @@ -2739,6 +2737,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); @@ -2749,8 +2749,6 @@ TEST_F(NVFuserTest, 
FusionTuringMatmulTT_CUDA) { // Matmul test on Turing TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { - NVFUSER_TEST_CUDA_ARCH_GUARD(7, 5); - Fusion fusion; FusionGuard fg(&fusion); int M = 512, N = 256, K = 128; @@ -2889,6 +2887,8 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { auto t1 = at::randn({K, N}, options); FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {t0, t1})); fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); From 0e204aa030f64f4e08e560595c3a9b1085b63408 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 23 May 2022 17:12:51 -0700 Subject: [PATCH 54/57] fix rebase --- torch/csrc/jit/codegen/cuda/mma_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 5c9218d79e72..9ac067eb8b77 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -86,7 +86,7 @@ bool isVolta(MmaOptions::MacroType macro) { } bool isTuring(MmaOptions::MacroType macro) { - return false; + return macro == MmaOptions::MacroType::Turing_16_8_16; } bool isAmpere(MmaOptions::MacroType macro) { From e92430e793a2cc0349e0b0a8140835539b7b5eb5 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 23 May 2022 17:32:54 -0700 Subject: [PATCH 55/57] rebase fix --- torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 3d012e7d33ea..c671001880be 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2739,7 +2739,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 7, 5, fe.compileFusion(&fusion, {t0, t1})); - fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); @@ -2889,7 +2889,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( 7, 5, fe.compileFusion(&fusion, {t0, t1})); - fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); From 37ce5aa332a92633a2514e31f853f06280b4de2a Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 23 May 2022 22:55:44 -0700 Subject: [PATCH 56/57] constexpr --- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 7740802c9330..a4745143a99b 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -43,9 +43,9 @@ DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) { // Upper half warp has 8 bytes offset from aligned in .x2 option // of ldmatrix. Currently no support for .x1 so assume always // adjust by half warp. - const unsigned half_warp = 16; + constexpr unsigned half_warp = 16; // Need to adjust to 16 byte alignment, mask out un-aligned component. - const unsigned mask_out = 16 - 1; + constexpr unsigned mask_out = 16 - 1; // Adjust only in upper half warp. 
// use bit math to reduce strength if (thread_id & half_warp) { From 3d8cb3a1ab8dcc10b9ce1e73045aa6037fcf4539 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 23 May 2022 23:10:02 -0700 Subject: [PATCH 57/57] rename swizzle enum --- torch/csrc/jit/codegen/cuda/mma_type.h | 4 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 2 +- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 92 +++++++++---------- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 610d50c233fa..6b94d74a4f5b 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -74,7 +74,7 @@ struct MmaOptions { enum class MmaInputLayout { NT = 0, TT, TN }; //! Utility to annotate which input of mma this option struct describes - enum class Operand { NotOperand = 0, A, B }; + enum class Operand { Accumulator = 0, A, B }; //! Utility to annotate which mma macro this config uses. MacroType macro = MacroType::NoMMA; @@ -118,7 +118,7 @@ class TORCH_CUDA_CU_API MmaBuilder { //! Specifies which element in the mma op this builder is generating //! parameters for, i.e. A or B. This is useful when generating //! data swizzles for different elements of mma. - //! - Operand::NotOperand means the parameters describe accumulator in mma + //! - Operand::Accumulator means the parameters describe accumulator in mma //! op. //! - This option is ignored when configuring the mma operator itself. MmaBuilder& operand(MmaOptions::Operand a_or_b); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 2bcb905173d7..d4d13a6a1fd7 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -1055,7 +1055,7 @@ bool TensorView::isEmptyTensor() const { void TensorView::applyMmaSwizzle(MmaOptions options) { switch (options.operand) { - case MmaOptions::Operand::NotOperand: + case MmaOptions::Operand::Accumulator: mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options); break; case MmaOptions::Operand::A: diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index c671001880be..4bb2542f10f6 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -189,9 +189,9 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { // Schedule the output instruction tile. // Assumes last 3 dims are mnk tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Set memory type. 
tv0cw->setMemoryType(MemoryType::Shared); @@ -255,9 +255,9 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -323,9 +323,9 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { tv2c->reorder({{0, 2}, {1, 0}, {2, 1}}); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -550,12 +550,12 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { // Use WarpMmaSwizzler for the innermost instruction tile (Mi,Ni, Ki) on // output tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // -6 -5 -4 -3 -2 -1 // [Mwo Nwo Mw Nw Mi Ni] tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Inline broadcast with smem write. tv0b->computeAt(tv0cw, -2); @@ -705,9 +705,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0b->computeAt(tv0cw, -2); tv1b->computeAt(tv1cw, -2); @@ -858,9 +858,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0b->computeAt(tv0cw, -2); tv1b->computeAt(tv1cw, -2); @@ -944,9 +944,9 @@ TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -1018,9 +1018,9 @@ TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { // [M,K,N] -> [M,N,K] tv2c->reorder({{-2, -1}, {-1, -2}}); tv2c->applyMmaSwizzle( - 
mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -1097,9 +1097,9 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { // [K,M,N] -> [M,N,K] tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -1235,9 +1235,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -1387,9 +1387,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -1540,9 +1540,9 @@ TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -1732,9 +1732,9 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv4c->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); tv4->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); // Schedule gemm 1: // ------------------------------------------------------------------ @@ -1802,13 +1802,13 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv3c->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3cw->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3h->applyMmaSwizzle( - 
mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3cw->setMemoryType(MemoryType::Shared); // Parallelize @@ -2034,9 +2034,9 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv4c->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); tv4->applyMmaSwizzle( - mma_builder2.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); // Schedule gemm 1: // ------------------------------------------------------------------ @@ -2112,9 +2112,9 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // // // --------------------------------------------------------------------------- tv3c->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); tv3->applyMmaSwizzle( - mma_builder1.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, // mma_builder1.build()); @@ -2279,9 +2279,9 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -2352,9 +2352,9 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { // [M,K,N] -> [M,N,K] tv2c->reorder({{-2, -1}, {-1, -2}}); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -2428,9 +2428,9 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { // [K,M,N] -> [M,N,K] tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv0cw->setMemoryType(MemoryType::Shared); tv1cw->setMemoryType(MemoryType::Shared); @@ -2565,9 +2565,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - 
mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -2714,9 +2714,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10 @@ -2864,9 +2864,9 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { // Schedule mma output // --------------------------------------------------------------------------- tv2c->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); tv2->applyMmaSwizzle( - mma_builder.operand(MmaOptions::Operand::NotOperand).build()); + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); // Parallelize // 0 1 2 3 4 5 6 7 8 9 10