@@ -943,6 +943,34 @@ void initializeCudaContext() {
943
943
}
944
944
}
945
945
946
+ namespace {
947
+
948
+ // Dump PTX or CUBIN to a file
949
+ void dumpCompiledCodeToFile (
950
+ const nvrtcProgram& program,
951
+ int fusion_id,
952
+ bool dump_cubin) {
953
+ const auto getSize = dump_cubin
954
+ ? at::globalContext ().getNVRTC ().nvrtcGetCUBINSize
955
+ : at::globalContext ().getNVRTC ().nvrtcGetPTXSize ;
956
+ const auto getCode = dump_cubin ? at::globalContext ().getNVRTC ().nvrtcGetCUBIN
957
+ : at::globalContext ().getNVRTC ().nvrtcGetPTX ;
958
+ size_t size = 0 ;
959
+ AT_CUDA_NVRTC_CHECK (getSize (program, &size));
960
+ std::vector<char > code (size);
961
+ AT_CUDA_NVRTC_CHECK (getCode (program, code.data ()));
962
+ std::stringstream file_name;
963
+ file_name << " __tmp_kernel" << fusion_id << " ."
964
+ << (dump_cubin ? " cubin" : " ptx" );
965
+ std::cout << " PRINTING: " << file_name.str () << std::endl;
966
+ std::ofstream out (file_name.str ());
967
+ TORCH_INTERNAL_ASSERT (out.is_open ());
968
+ out.write (code.data (), size);
969
+ out.close ();
970
+ }
971
+
972
+ } // namespace
973
+
946
974
std::pair<NvrtcFunction, std::string> nvrtcCompile (
947
975
const std::string& code,
948
976
const std::string& func_name,
@@ -1183,92 +1211,24 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
1183
1211
AT_CUDA_NVRTC_CHECK (getFunc (program, ptx.data ()));
1184
1212
}
1185
1213
1186
- NvrtcFunction compiled_kernel_;
1187
-
1188
- // TODO: We do go through different code path, should investigate whether this
1189
- // has an impact on generated binary.
1190
- #ifndef __HIP_PLATFORM_HCC__
1191
- const char * prefix_env = getenv (" PYTORCH_NVFUSER_CUBIN" );
1192
- if (prefix_env) {
1193
- FUSER_PERF_SCOPE (" executor_utils::Nvrtc::LoadCUBIN" );
1194
-
1195
- // Output ptx file
1196
- std::stringstream output_file_name;
1197
- output_file_name << prefix_env << " _" << id
1198
- << (compile_to_sass ? " .cubin" : " .ptx" );
1199
- std::ofstream outputFile (output_file_name.str ().c_str (), std::ios::out);
1200
- if (outputFile.is_open ()) {
1201
- outputFile.write (ptx.data (), ptx.size ());
1202
- outputFile.close ();
1203
- }
1204
-
1205
- if (compile_to_sass) {
1206
- FUSER_PERF_SCOPE (" executor_utils::Nvrtc::LoadPTX" );
1207
-
1208
- // load sass directly
1209
- AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuModuleLoadDataEx (
1210
- &(compiled_kernel_.module ),
1211
- ptx.data (),
1212
- options.size (),
1213
- options.data (),
1214
- option_vals.data ()));
1215
- } else {
1216
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1217
- CUlinkState linkState;
1218
-
1219
- AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuLinkCreate (
1220
- // 0, nullptr, nullptr, &linkState));
1221
- options.size (),
1222
- options.data (),
1223
- option_vals.data (),
1224
- &linkState));
1225
-
1226
- AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuLinkAddData (
1227
- linkState,
1228
- CU_JIT_INPUT_PTX,
1229
- ptx.data (),
1230
- ptx_size,
1231
- " compiling PTX" ,
1232
- 0 ,
1233
- nullptr ,
1234
- nullptr ));
1235
-
1236
- if (isDebugDumpEnabled (DebugDumpOption::PrintPtxasLog)) {
1237
- std::cout << info_log.data () << std::endl;
1238
- }
1239
-
1240
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1241
- size_t cubinSize;
1242
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1243
- void * cubin;
1244
- AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuLinkComplete (
1245
- linkState, &cubin, &cubinSize));
1214
+ if (isDebugDumpEnabled (DebugDumpOption::Ptx)) {
1215
+ dumpCompiledCodeToFile (program, id, false );
1216
+ }
1246
1217
1247
- // Output binary file
1248
- std::stringstream cubin_file_name;
1249
- cubin_file_name << prefix_env << " _" << id << " .cubin" ;
1218
+ if (isDebugDumpEnabled (DebugDumpOption::Cubin)) {
1219
+ TORCH_INTERNAL_ASSERT (
1220
+ compile_to_sass,
1221
+ " CUBIN not available as the kernel was compiled only to PTX" );
1222
+ dumpCompiledCodeToFile (program, id, true );
1223
+ }
1250
1224
1251
- std::ofstream myCubinFile (
1252
- cubin_file_name.str ().c_str (), std::ios::out | std::ios::binary);
1225
+ NvrtcFunction compiled_kernel_;
1253
1226
1254
- if (myCubinFile.is_open ()) {
1255
- myCubinFile.write (static_cast <const char *>(cubin), cubinSize);
1256
- myCubinFile.close ();
1257
- }
1258
- // load compiled cubin
1259
- // AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
1260
- // &(compiled_kernel_.module), cubin));
1261
- AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuModuleLoadDataEx (
1262
- &(compiled_kernel_.module ),
1263
- cubin,
1264
- options.size (),
1265
- options.data (),
1266
- option_vals.data ()));
1267
- }
1268
- } else {
1227
+ #ifndef __HIP_PLATFORM_HCC__
1228
+ {
1269
1229
FUSER_PERF_SCOPE (" executor_utils::Nvrtc::LoadPTX" );
1270
1230
1271
- // load ptx directly
1231
+ // load ptx or cubin directly
1272
1232
AT_CUDA_DRIVER_CHECK (at::globalContext ().getNVRTC ().cuModuleLoadDataEx (
1273
1233
&(compiled_kernel_.module ),
1274
1234
ptx.data (),
0 commit comments