diff --git a/tools/amd_build/pyHIPIFY/hipify-python.py b/tools/amd_build/pyHIPIFY/hipify-python.py index 355948433ea7cd..d2f7bbc091a082 100755 --- a/tools/amd_build/pyHIPIFY/hipify-python.py +++ b/tools/amd_build/pyHIPIFY/hipify-python.py @@ -487,6 +487,34 @@ def replace_math_functions(input_string): return output_string +def hip_header_magic(input_string): + """If the file makes kernel builtin calls and does not include the cuda_runtime.h header, + then automatically add an #include to match the "magic" includes provided by NVCC. + TODO: + Update logic to ignore cases where the cuda_runtime.h is included by another file. + """ + + # Copy the input. + output_string = input_string + + # Check if one of the following headers is already included. + headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"] + if any(re.search(r'#include ("{0}"|<{0}>)'.format(ext), output_string) for ext in headers): + return output_string + + # Rough logic to detect if we're inside device code + hasDeviceLogic = "hipLaunchKernelGGL" in output_string + hasDeviceLogic += "__global__" in output_string + hasDeviceLogic += "__shared__" in output_string + hasDeviceLogic += re.search(r"[:]?[:]?\b(__syncthreads)\b(\w*\()", output_string) is not None + + # If device logic found, provide the necessary header. + if hasDeviceLogic: + output_string = '#include "hip/hip_runtime.h"\n' + input_string + + return output_string + + def replace_extern_shared(input_string): """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead. https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__ @@ -715,6 +743,9 @@ def preprocessor(filepath, stats, hipify_caffe2): # Replace __forceinline__ with inline output_source = replace_forceinline(output_source) + # Include header if device code is contained. + output_source = hip_header_magic(output_source) + # Replace the extern __shared__ output_source = replace_extern_shared(output_source)