Skip to content

Commit 6e8f953

Browse files
authored
Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916)
Also cleaned up the relevant code path a little
1 parent 5eefa9a commit 6e8f953

File tree

4 files changed

+56
-86
lines changed

4 files changed

+56
-86
lines changed

torch/csrc/jit/codegen/cuda/executor_utils.cpp

Lines changed: 41 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,34 @@ void initializeCudaContext() {
943943
}
944944
}
945945

946+
namespace {
947+
948+
// Dump PTX or CUBIN to a file
949+
void dumpCompiledCodeToFile(
950+
const nvrtcProgram& program,
951+
int fusion_id,
952+
bool dump_cubin) {
953+
const auto getSize = dump_cubin
954+
? at::globalContext().getNVRTC().nvrtcGetCUBINSize
955+
: at::globalContext().getNVRTC().nvrtcGetPTXSize;
956+
const auto getCode = dump_cubin ? at::globalContext().getNVRTC().nvrtcGetCUBIN
957+
: at::globalContext().getNVRTC().nvrtcGetPTX;
958+
size_t size = 0;
959+
AT_CUDA_NVRTC_CHECK(getSize(program, &size));
960+
std::vector<char> code(size);
961+
AT_CUDA_NVRTC_CHECK(getCode(program, code.data()));
962+
std::stringstream file_name;
963+
file_name << "__tmp_kernel" << fusion_id << "."
964+
<< (dump_cubin ? "cubin" : "ptx");
965+
std::cout << "PRINTING: " << file_name.str() << std::endl;
966+
std::ofstream out(file_name.str());
967+
TORCH_INTERNAL_ASSERT(out.is_open());
968+
out.write(code.data(), size);
969+
out.close();
970+
}
971+
972+
} // namespace
973+
946974
std::pair<NvrtcFunction, std::string> nvrtcCompile(
947975
const std::string& code,
948976
const std::string& func_name,
@@ -1183,92 +1211,24 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
11831211
AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
11841212
}
11851213

1186-
NvrtcFunction compiled_kernel_;
1187-
1188-
// TODO: We do go through different code path, should investigate whether this
1189-
// has an impact on generated binary.
1190-
#ifndef __HIP_PLATFORM_HCC__
1191-
const char* prefix_env = getenv("PYTORCH_NVFUSER_CUBIN");
1192-
if (prefix_env) {
1193-
FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadCUBIN");
1194-
1195-
// Output ptx file
1196-
std::stringstream output_file_name;
1197-
output_file_name << prefix_env << "_" << id
1198-
<< (compile_to_sass ? ".cubin" : ".ptx");
1199-
std::ofstream outputFile(output_file_name.str().c_str(), std::ios::out);
1200-
if (outputFile.is_open()) {
1201-
outputFile.write(ptx.data(), ptx.size());
1202-
outputFile.close();
1203-
}
1204-
1205-
if (compile_to_sass) {
1206-
FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
1207-
1208-
// load sass directly
1209-
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
1210-
&(compiled_kernel_.module),
1211-
ptx.data(),
1212-
options.size(),
1213-
options.data(),
1214-
option_vals.data()));
1215-
} else {
1216-
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1217-
CUlinkState linkState;
1218-
1219-
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate(
1220-
// 0, nullptr, nullptr, &linkState));
1221-
options.size(),
1222-
options.data(),
1223-
option_vals.data(),
1224-
&linkState));
1225-
1226-
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData(
1227-
linkState,
1228-
CU_JIT_INPUT_PTX,
1229-
ptx.data(),
1230-
ptx_size,
1231-
"compiling PTX",
1232-
0,
1233-
nullptr,
1234-
nullptr));
1235-
1236-
if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
1237-
std::cout << info_log.data() << std::endl;
1238-
}
1239-
1240-
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1241-
size_t cubinSize;
1242-
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1243-
void* cubin;
1244-
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete(
1245-
linkState, &cubin, &cubinSize));
1214+
if (isDebugDumpEnabled(DebugDumpOption::Ptx)) {
1215+
dumpCompiledCodeToFile(program, id, false);
1216+
}
12461217

1247-
// Output binary file
1248-
std::stringstream cubin_file_name;
1249-
cubin_file_name << prefix_env << "_" << id << ".cubin";
1218+
if (isDebugDumpEnabled(DebugDumpOption::Cubin)) {
1219+
TORCH_INTERNAL_ASSERT(
1220+
compile_to_sass,
1221+
"CUBIN not available as the kernel was compiled only to PTX");
1222+
dumpCompiledCodeToFile(program, id, true);
1223+
}
12501224

1251-
std::ofstream myCubinFile(
1252-
cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
1225+
NvrtcFunction compiled_kernel_;
12531226

1254-
if (myCubinFile.is_open()) {
1255-
myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
1256-
myCubinFile.close();
1257-
}
1258-
// load compiled cubin
1259-
// AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
1260-
// &(compiled_kernel_.module), cubin));
1261-
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
1262-
&(compiled_kernel_.module),
1263-
cubin,
1264-
options.size(),
1265-
options.data(),
1266-
option_vals.data()));
1267-
}
1268-
} else {
1227+
#ifndef __HIP_PLATFORM_HCC__
1228+
{
12691229
FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
12701230

1271-
// load ptx directly
1231+
// load ptx or cubin directly
12721232
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
12731233
&(compiled_kernel_.module),
12741234
ptx.data(),

torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ __device__ double uniform(unsigned int x, unsigned int y) {
5656
return z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0);
5757
}
5858

59-
__device__ double randLike(const uint4 &rng_result, int rng_component) {
60-
return uniform((&rng_result.x)[rng_component * 2], (&rng_result.x)[rng_component * 2 + 1]);
59+
// Produce a double from two consecutive words of rng_result. The words are
// read by indexing from the address of rng_result.x: rng_component selects
// the pair at offsets (2 * rng_component) and (2 * rng_component + 1).
__device__ double randLike(const uint4& rng_result, int rng_component) {
  const unsigned int* words = &rng_result.x;
  const int first = rng_component * 2;
  return uniform(words[first], words[first + 1]);
}
6264

63-
__device__ float randLikef(const uint4 &rng_result, int rng_component) {
65+
// Produce a float from a single word of rng_result. The word is read by
// indexing from the address of rng_result.x at offset rng_component.
__device__ float randLikef(const uint4& rng_result, int rng_component) {
  const unsigned int* words = &rng_result.x;
  return uniformf(words[rng_component]);
}

torch/csrc/jit/codegen/cuda/utils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ auto parseDebugDumpOptions() {
3838
{DebugDumpOption::Halo, false},
3939
{DebugDumpOption::PerfDebugVerbose, false},
4040
{DebugDumpOption::TransformPropagator, false},
41-
{DebugDumpOption::InlinePropagator, false}};
41+
{DebugDumpOption::InlinePropagator, false},
42+
{DebugDumpOption::Cubin, false},
43+
{DebugDumpOption::Ptx, false}};
4244

4345
if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
4446
c10::string_view options_view(dump_options);
@@ -91,6 +93,10 @@ auto parseDebugDumpOptions() {
9193
options_map[DebugDumpOption::TransformPropagator] = true;
9294
} else if (token == "inline_propagator") {
9395
options_map[DebugDumpOption::InlinePropagator] = true;
96+
} else if (token == "cubin") {
97+
options_map[DebugDumpOption::Cubin] = true;
98+
} else if (token == "ptx") {
99+
options_map[DebugDumpOption::Ptx] = true;
94100
} else {
95101
TORCH_CHECK(
96102
false,
@@ -102,7 +108,7 @@ auto parseDebugDumpOptions() {
102108
"\tkernel_args, dump_eff_bandwidth, draw_segmented_fusion,\n",
103109
"\tscheduler_params, parallel_dimensions, buffer_reuse_verbose,\n",
104110
"\tptxas_verbose, halo, segmenter_logging, perf_debug_verbose\n",
105-
"\ttransform_propagator, inline_propagator\n");
111+
"\ttransform_propagator, inline_propagator, cubin, ptx\n");
106112
}
107113
options_view = (end_pos != c10::string_view::npos)
108114
? options_view.substr(end_pos + 1)

torch/csrc/jit/codegen/cuda/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ enum class DebugDumpOption {
5050
//! path and replay result
5151
InlinePropagator, //! When running InlinePropagator, print propagation
5252
//! path and inlining result
53+
Cubin, //! Dump compiled CUBIN
54+
Ptx //! Dump compiled PTX
5355
};
5456

5557
TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);

0 commit comments

Comments
 (0)