diff --git a/build_variables.bzl b/build_variables.bzl index f70d4280825af..517a4c8887230 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -707,6 +707,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/lower_validation.cpp", "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", + "torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", "torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 38c16f0dca676..af08eccd6a63f 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -251,6 +252,27 @@ void FusionExecutor::compileFusion( kernel->print(); } + if (isDebugDumpEnabled(DebugDumpOption::BankConflictInfo)) { + auto bank_conflict_info = getBankConflictInfo(kernel); + if (bank_conflict_info.empty()) { + std::cout << "===== No bank confliction =====" << std::endl; + } else { + std::cout << "======= Bank confliction =======" << std::endl; + for (auto info : bank_conflict_info) { + std::cout << "Expr: " << info.first->toString() << std::endl; + auto conflict = info.second; + if (conflict.first > 1) { + std::cout << "input conflict: " << conflict.first << " way, "; + } + if (conflict.second > 1) { + std::cout << "output conflict: " << conflict.second << " way"; + } + std::cout << std::endl; + } + std::cout << "================================" << std::endl; + } + } + kernel_code_ = codegen::generateCudaKernel(kernel, kernelName()); const auto structured_code = getStructuredCode(kernel_code_); diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 7dda464a4fac8..9527520f6041f 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ 
b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -54,7 +54,11 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) { TORCH_CHECK( value->definition() == nullptr, "Tried to bind to a value that is computed in the fusion IR"); - known_values_[value] = concrete_value; + if (value->isA()) { + known_named_scalars_[value->as()->name()] = concrete_value; + } else { + known_values_[value] = concrete_value; + } } c10::optional ExpressionEvaluator::evaluate(Val* value) { @@ -88,7 +92,7 @@ void ExpressionEvaluator::print() const { c10::optional ExpressionEvaluator::getValue(Val* value) { TORCH_INTERNAL_ASSERT( value->isAnInt() || value->isADouble(), - "Expression Evaluation does not support values other than integers at this time."); + "Expression Evaluation does not support values other than integers/doubles at this time."); if (value->getValType().value() == ValType::Scalar) { if (value->isAnInt() && value->as()->value().has_value()) { @@ -99,9 +103,16 @@ c10::optional ExpressionEvaluator::getValue(Val* value) { } } - const auto it = known_values_.find(value); - return it != known_values_.end() ? c10::optional(it->second) - : c10::nullopt; + if (value->isA()) { + const auto it = known_named_scalars_.find(value->as()->name()); + return it != known_named_scalars_.end() + ? c10::optional(it->second) + : c10::nullopt; + } else { + const auto it = known_values_.find(value); + return it != known_values_.end() ? 
c10::optional(it->second) + : c10::nullopt; + } } void ExpressionEvaluator::handle(UnaryOp* uop) { diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index 8d906ff58e43d..d6001137725d7 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -49,9 +49,11 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { void handle(UnaryOp*) final; void handle(BinaryOp*) final; + // TODO: handle swizzle private: std::unordered_map known_values_; + std::unordered_map known_named_scalars_; Fusion* fusion_ = nullptr; FusionPrecomputedValues* evaluator_precomputed_values_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 04c367c667275..e4f24f0473a19 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -339,6 +340,20 @@ void Fusion::printKernel(DataType index_type) { std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel()); } +std::unordered_map> Fusion::bankConflictInfo( + DataType index_type) { + GpuLower lower(this, index_type); + auto kernel = lower.kernel(); + auto info = getBankConflictInfo(kernel); + // The container of exprs goes out of scope, so we return a map of string here + std::unordered_map> result; + result.reserve(info.size()); + for (auto i : info) { + result[i.first->toString()] = i.second; + } + return result; +} + void Fusion::printMath(bool from_outputs_only) { FUSER_PERF_SCOPE("Fusion::printMath"); diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index e726d793be756..2c0c59fae2b9b 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -136,6 +136,10 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer { //! 
Lower the fusion and print a kernel void printKernel(DataType index_type = DataType::Int); + //! Lower the fusion and evaluate bank conflict info + std::unordered_map> bankConflictInfo( + DataType index_type = DataType::Int); + //! Return a list of topologically sorted expressions. This only includes //! exprs required to genereate registered outputs. std::vector exprs(); diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp new file mode 100644 index 0000000000000..2f29d79f26678 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -0,0 +1,159 @@ +#include + +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +bool isSmemTensorIndex(Val* value) { + return value->isA() && + value->as()->view()->getMemoryType() == + MemoryType::Shared; +} + +int64_t getVectorizeSize(kir::TensorIndex* ti) { + for (auto id : ti->view()->domain()->domain()) { + if (!isParallelTypeVectorize(id->getParallelType())) { + continue; + } + + ExpressionEvaluator expr_eval(id->fusion()); + auto vector_size_optional = expr_eval.evaluate(id->extent()); + + TORCH_INTERNAL_ASSERT( + vector_size_optional.has_value(), + "Could not evaluate constant value bound to vectorized dim."); + + return vector_size_optional->as(); + } + return 1; +} + +inline int64_t getPhaseSize(int64_t word_size_bytes) { + if (word_size_bytes == 16) { + return 8; + } + if (word_size_bytes == 8) { + return 16; + } + return 32; +} + +std::vector evaluateAddressesOnFirstPhase( + kir::TensorIndex* ti, + const std::vector& for_loops) { + std::vector addresses; + const auto word_size_bytes = + dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); + int64_t phase_size = getPhaseSize(word_size_bytes); + + for (auto tidx : c10::irange(phase_size)) { + int64_t index = 0; + ExpressionEvaluator expr_eval(ti->fusion()); + for (auto fl : for_loops) { + if 
(fl->index()->isA() && + fl->index()->as()->name() == "threadIdx.x") { + expr_eval.bind(fl->index(), tidx); + } else { + expr_eval.bind(fl->index(), 0); + } + } + for (auto ind : ti->indices()) { + index += expr_eval.evaluate(ind)->as(); + } + addresses.emplace_back(index * word_size_bytes); + } + return addresses; +} + +int getConflictWays(const std::vector& addresses) { + std::unordered_set words_by_bank[32]; + for (auto addr : addresses) { + int64_t word = addr / 4; + int64_t bank = word % 32; + words_by_bank[bank].insert(word); + } + int conflict = 1; + for (const auto& words : words_by_bank) { + conflict = std::max(conflict, words.size()); + } + return conflict; +} + +} // namespace + +class BankConflictInfo : public kir::IrVisitor { + public: + static std::unordered_map> get( + const std::vector& exprs) { + return BankConflictInfo(exprs).bank_conflict_info_; + } + + private: + BankConflictInfo(const std::vector& exprs) { + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + if (expr->isA()) { + auto uop = expr->as(); + if (uop->getUnaryOpType() != UnaryOpType::Set) { + return; + } + std::pair conflict_ways{0, 0}; + if (isSmemTensorIndex(uop->in())) { + conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( + uop->in()->as(), for_loops_)); + } + if (isSmemTensorIndex(uop->out())) { + conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( + uop->out()->as(), for_loops_)); + } + if (conflict_ways.first > 1 || conflict_ways.second > 1) { + bank_conflict_info_[expr] = conflict_ways; + } + } else if (expr->isA()) { + auto ldst = expr->as(); + std::pair conflict_ways{0, 0}; + if (isSmemTensorIndex(ldst->in())) { + conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( + ldst->in()->as(), for_loops_)); + } + if (isSmemTensorIndex(ldst->out())) { + conflict_ways.second = 
getConflictWays(evaluateAddressesOnFirstPhase( + ldst->out()->as(), for_loops_)); + } + if (conflict_ways.first > 1 || conflict_ways.second > 1) { + bank_conflict_info_[expr] = conflict_ways; + } + } + } + + std::unordered_map> bank_conflict_info_; +}; + +std::unordered_map> getBankConflictInfo( + kir::Kernel* kernel) { + return BankConflictInfo::get(kernel->topLevelExprs()); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h new file mode 100644 index 0000000000000..12c12d4bff4d8 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// for more info on shared memory access see page 54-72 of: +// https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + +// Warning: The bank conflict checking utility here is not a replacement for +// Nsight Compute. This utility currently has the following assumptions and +// limitations: +// +// 1. This utility assumes that `blockDim.x` is large enough to hold one phase +// 2. This utility assumes that the address only depends on loop variables +// (it cannot depend on things like `T0.stride[0]` or `blockDim.x`) +// 3. This utility assumes that the data of the tensor is accessed by +// `T0[index]`, where `index` is the one stored in the `TensorIndex` +// object. +// 4. This utility only checks the first iteration, and the start of every +// loop variable is assumed to be `0` (if we have something like +// `T1_s[tidx, 5]`, then different iterations may have different +// results, and this utility does not yet handle all of them) +// 5. 
This utility assumes that all tensors are independent, which means: +// 5.1 All shared memory tensors are allocated starting from a multiple of +// 4*32 bytes +// 5.2 The only source of bank conflicts is within a tensor. +// There is no bank conflict between different tensors. +// +// Also note that this utility will not provide accurate estimates if the above +// assumptions are not satisfied + +std::unordered_map> getBankConflictInfo( + kir::Kernel* kernel); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index ad43b0ed4e07d..8c00fea08489c 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1023,6 +1023,122 @@ TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) { } } +TEST_F(NVFuserTest, FusionTransposeBankConflict1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(0)->parallelize(ParallelType::TIDx); + 
tv2->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + for (auto info : bank_conflict_info) { + std::pair expect{0, 32}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}, DataType::Bool); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + for (auto info : bank_conflict_info) { + std::pair expect{8, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->merge(0); + tv1->split(0, 4); + tv1->split(0, 8); + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + tv1->axis(0)->parallelize(ParallelType::TIDx); + // T1 [TIDx(32), 8, V(4)] + + tv2->setMemoryType(MemoryType::Shared); + tv2->merge(0); + tv2->split(0, 4); + tv2->split(0, 32); + tv2->axis(1)->parallelize(ParallelType::TIDx); + // T2 [8, TIDx(32), 4] + + tv3->merge(0); + tv3->split(0, 2); + tv3->split(0, 32); + tv3->axis(1)->parallelize(ParallelType::TIDx); + // T3 [16, TIDx(32), 2] + + auto bank_conflict_info = fusion.bankConflictInfo(); + + for (auto info : bank_conflict_info) { + std::pair expect1{0, 8}; + std::pair expect2{8, 4}; + std::pair expect3{2, 0}; + TORCH_CHECK( + info.second == expect1 || info.second == expect2 || + 
info.second == expect3); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index 101aac58479ad..40fd6074ccc50 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -39,7 +39,8 @@ auto parseDebugDumpOptions() { {DebugDumpOption::PerfDebugVerbose, false}, {DebugDumpOption::TransformPropagator, false}, {DebugDumpOption::Cubin, false}, - {DebugDumpOption::Ptx, false}}; + {DebugDumpOption::Ptx, false}, + {DebugDumpOption::BankConflictInfo, false}}; if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) { c10::string_view options_view(dump_options); @@ -94,6 +95,8 @@ auto parseDebugDumpOptions() { options_map[DebugDumpOption::Cubin] = true; } else if (token == "ptx") { options_map[DebugDumpOption::Ptx] = true; + } else if (token == "bank_conflict") { + options_map[DebugDumpOption::BankConflictInfo] = true; } else { TORCH_CHECK( false, @@ -105,7 +108,7 @@ auto parseDebugDumpOptions() { "\tkernel_args, dump_eff_bandwidth, draw_segmented_fusion,\n", "\tscheduler_params, parallel_dimensions, buffer_reuse_verbose,\n", "\tptxas_verbose, halo, segmenter_logging, perf_debug_verbose\n", - "\ttransform_propagator, cubin, ptx\n"); + "\ttransform_propagator, cubin, ptx, bank_conflict\n"); } options_view = (end_pos != c10::string_view::npos) ? options_view.substr(end_pos + 1) diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index 5e69ac2bb22b9..71430c1f20514 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -55,7 +55,8 @@ enum class DebugDumpOption { TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result Cubin, //! Dump compiled CUBIN - Ptx //! Dump compiled PTX + Ptx, //! Dump compiled PTX + BankConflictInfo //! 
Dump bank conflict info }; TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);