From aa3c6f087fccf69fbe1bcc7825ae6037fb533ad2 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Wed, 17 Aug 2022 09:49:30 -0700
Subject: [PATCH] Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN

Also cleaned up the relevant code path a little
---
 .../csrc/jit/codegen/cuda/executor_utils.cpp  | 122 ++++++------------
 .../codegen/cuda/runtime/random_numbers.cu    |   8 +-
 torch/csrc/jit/codegen/cuda/utils.cpp         |  10 +-
 torch/csrc/jit/codegen/cuda/utils.h           |   2 +
 4 files changed, 56 insertions(+), 86 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
index 05f78434ed1c..b0f35962c5dd 100644
--- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -943,6 +943,34 @@ void initializeCudaContext() {
   }
 }
 
+namespace {
+
+// Dump PTX or CUBIN to a file
+void dumpCompiledCodeToFile(
+    const nvrtcProgram& program,
+    int fusion_id,
+    bool dump_cubin) {
+  const auto getSize = dump_cubin
+      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
+      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
+  const auto getCode = dump_cubin ? at::globalContext().getNVRTC().nvrtcGetCUBIN
+                                  : at::globalContext().getNVRTC().nvrtcGetPTX;
+  size_t size = 0;
+  AT_CUDA_NVRTC_CHECK(getSize(program, &size));
+  std::vector<char> code(size);
+  AT_CUDA_NVRTC_CHECK(getCode(program, code.data()));
+  std::stringstream file_name;
+  file_name << "__tmp_kernel" << fusion_id << "."
+            << (dump_cubin ? "cubin" : "ptx");
+  std::cout << "PRINTING: " << file_name.str() << std::endl;
+  std::ofstream out(file_name.str());
+  TORCH_INTERNAL_ASSERT(out.is_open());
+  out.write(code.data(), size);
+  out.close();
+}
+
+} // namespace
+
 std::pair<NvrtcFunction, std::string> nvrtcCompile(
     const std::string& code,
     const std::string& func_name,
@@ -1183,92 +1211,24 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
     AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
   }
 
-  NvrtcFunction compiled_kernel_;
-
-  // TODO: We do go through different code path, should investigate whether this
-  // has an impact on generated binary.
-#ifndef __HIP_PLATFORM_HCC__
-  const char* prefix_env = getenv("PYTORCH_NVFUSER_CUBIN");
-  if (prefix_env) {
-    FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadCUBIN");
-
-    // Output ptx file
-    std::stringstream output_file_name;
-    output_file_name << prefix_env << "_" << id
".cubin" : ".ptx"); - std::ofstream outputFile(output_file_name.str().c_str(), std::ios::out); - if (outputFile.is_open()) { - outputFile.write(ptx.data(), ptx.size()); - outputFile.close(); - } - - if (compile_to_sass) { - FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX"); - - // load sass directly - AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx( - &(compiled_kernel_.module), - ptx.data(), - options.size(), - options.data(), - option_vals.data())); - } else { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - CUlinkState linkState; - - AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate( - // 0, nullptr, nullptr, &linkState)); - options.size(), - options.data(), - option_vals.data(), - &linkState)); - - AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData( - linkState, - CU_JIT_INPUT_PTX, - ptx.data(), - ptx_size, - "compiling PTX", - 0, - nullptr, - nullptr)); - - if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) { - std::cout << info_log.data() << std::endl; - } - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t cubinSize; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* cubin; - AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete( - linkState, &cubin, &cubinSize)); + if (isDebugDumpEnabled(DebugDumpOption::Ptx)) { + dumpCompiledCodeToFile(program, id, false); + } - // Output binary file - std::stringstream cubin_file_name; - cubin_file_name << prefix_env << "_" << id << ".cubin"; + if (isDebugDumpEnabled(DebugDumpOption::Cubin)) { + TORCH_INTERNAL_ASSERT( + compile_to_sass, + "CUBIN not available as the kernel was compiled only to PTX"); + dumpCompiledCodeToFile(program, id, true); + } - std::ofstream myCubinFile( - cubin_file_name.str().c_str(), std::ios::out | std::ios::binary); + NvrtcFunction compiled_kernel_; - if (myCubinFile.is_open()) { - myCubinFile.write(static_cast(cubin), cubinSize); - myCubinFile.close(); - } - // load compiled cubin - // AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( - // &(compiled_kernel_.module), cubin)); - AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx( - &(compiled_kernel_.module), - cubin, - options.size(), - options.data(), - option_vals.data())); - } - } else { +#ifndef __HIP_PLATFORM_HCC__ + { FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX"); - // load ptx directly + // load ptx or cubin directly AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx( &(compiled_kernel_.module), ptx.data(), diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu index e26b99c06844..4736d8ac6b17 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu @@ -56,10 +56,12 @@ __device__ double uniform(unsigned int x, unsigned int y) { return z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0); } -__device__ double randLike(const uint4 &rng_result, int rng_component) { - return uniform((&rng_result.x)[rng_component * 2], (&rng_result.x)[rng_component * 2 + 1]); +__device__ double randLike(const uint4& rng_result, int rng_component) { + return uniform( + (&rng_result.x)[rng_component * 2], + (&rng_result.x)[rng_component * 2 + 1]); } -__device__ float randLikef(const uint4 &rng_result, int rng_component) { +__device__ float randLikef(const uint4& rng_result, int rng_component) { return uniformf((&rng_result.x)[rng_component]); } diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp 
diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp
index fcb818726c85..ab275f324a39 100644
--- a/torch/csrc/jit/codegen/cuda/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/utils.cpp
@@ -38,7 +38,9 @@ auto parseDebugDumpOptions() {
       {DebugDumpOption::Halo, false},
       {DebugDumpOption::PerfDebugVerbose, false},
       {DebugDumpOption::TransformPropagator, false},
-      {DebugDumpOption::InlinePropagator, false}};
+      {DebugDumpOption::InlinePropagator, false},
+      {DebugDumpOption::Cubin, false},
+      {DebugDumpOption::Ptx, false}};
 
   if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
     c10::string_view options_view(dump_options);
@@ -91,6 +93,10 @@ auto parseDebugDumpOptions() {
         options_map[DebugDumpOption::TransformPropagator] = true;
       } else if (token == "inline_propagator") {
         options_map[DebugDumpOption::InlinePropagator] = true;
+      } else if (token == "cubin") {
+        options_map[DebugDumpOption::Cubin] = true;
+      } else if (token == "ptx") {
+        options_map[DebugDumpOption::Ptx] = true;
       } else {
         TORCH_CHECK(
             false,
@@ -102,7 +108,7 @@ auto parseDebugDumpOptions() {
             "\tkernel_args, dump_eff_bandwidth, draw_segmented_fusion,\n",
             "\tscheduler_params, parallel_dimensions, buffer_reuse_verbose,\n",
             "\tptxas_verbose, halo, segmenter_logging, perf_debug_verbose\n",
-            "\ttransform_propagator, inline_propagator\n");
+            "\ttransform_propagator, inline_propagator, cubin, ptx\n");
       }
       options_view = (end_pos != c10::string_view::npos)
           ? options_view.substr(end_pos + 1)
diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h
index 77f5ab89a2f5..20b832153b3d 100644
--- a/torch/csrc/jit/codegen/cuda/utils.h
+++ b/torch/csrc/jit/codegen/cuda/utils.h
@@ -50,6 +50,8 @@ enum class DebugDumpOption {
   //! path and replay result
   InlinePropagator, //! When running InlinePropagator, print propagation
                     //! path and inlining result
+  Cubin, //! Dump compiled CUBIN
+  Ptx //! Dump compiled PTX
 };
 
 TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
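
Usage note (illustrative, not part of the patch): with this change,
PYTORCH_NVFUSER_DUMP=ptx writes __tmp_kernel<id>.ptx and
PYTORCH_NVFUSER_DUMP=cubin writes __tmp_kernel<id>.cubin; the cubin dump
requires the kernel to have been compiled to SASS, as the new
TORCH_INTERNAL_ASSERT enforces. For readers outside the nvfuser codebase,
the sketch below shows the same NVRTC size-query/fetch/write pattern that
dumpCompiledCodeToFile() wraps, as a self-contained program. It uses plain
NVRTC calls instead of at::globalContext().getNVRTC(); NVRTC_CHECK is a
hypothetical stand-in for AT_CUDA_NVRTC_CHECK, and the file and kernel
names are made up for the example.

#include <nvrtc.h>

#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <vector>

// Hypothetical stand-in for AT_CUDA_NVRTC_CHECK: abort on any NVRTC error.
#define NVRTC_CHECK(expr)                                          \
  do {                                                             \
    nvrtcResult status = (expr);                                   \
    if (status != NVRTC_SUCCESS) {                                 \
      std::fprintf(                                                \
          stderr, "NVRTC error: %s\n", nvrtcGetErrorString(status)); \
      std::exit(1);                                                \
    }                                                              \
  } while (0)

int main() {
  // Trivial kernel source, standing in for the code nvrtcCompile() receives.
  const char* src = "__global__ void kernel0(float* x) { x[0] = 1.0f; }\n";

  nvrtcProgram program;
  NVRTC_CHECK(
      nvrtcCreateProgram(&program, src, "kernel0.cu", 0, nullptr, nullptr));
  NVRTC_CHECK(nvrtcCompileProgram(program, 0, nullptr));

  // Two-step protocol, same as dumpCompiledCodeToFile(): query the size,
  // allocate a buffer, then fetch the PTX text. nvrtcGetCUBINSize/
  // nvrtcGetCUBIN follow the identical pattern when the program was
  // compiled to SASS for a specific -arch=sm_XX.
  size_t size = 0;
  NVRTC_CHECK(nvrtcGetPTXSize(program, &size));
  std::vector<char> code(size);
  NVRTC_CHECK(nvrtcGetPTX(program, code.data()));

  // Same naming scheme as the patch: __tmp_kernel<fusion_id>.<ext>.
  std::ofstream out("__tmp_kernel0.ptx");
  out.write(code.data(), size);

  NVRTC_CHECK(nvrtcDestroyProgram(&program));
  return 0;
}

Build sketch (assuming a CUDA toolkit install): nvcc -o dump_ptx dump_ptx.cpp
-lnvrtc. Note that the CUBIN query entry points were added around CUDA 11.1
and require an explicit SASS target, which mirrors the compile_to_sass
distinction the patch asserts on before dumping a CUBIN.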