diff --git a/tools/amd_build/pyHIPIFY/hipify-python.py b/tools/amd_build/pyHIPIFY/hipify-python.py index bd38e26949492d..7aae2481099fd4 100755 --- a/tools/amd_build/pyHIPIFY/hipify-python.py +++ b/tools/amd_build/pyHIPIFY/hipify-python.py @@ -331,6 +331,26 @@ def disable_asserts(input_string): output_string = output_string.replace(input_string[start:p_end + 1], "") return output_string +def replace_forceinline(input_string): + """__forceinline__'d methods can cause 'symbol multiply defined' errors in HIP. + Adding 'static' to all such methods leads to compilation errors, so + replacing '__forceinline__' with 'inline' as a workaround + https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_faq.md#what-if-hip-generates-error-of-symbol-multiply-defined-only-on-amd-machine + """ + output_string = input_string + output_string = re.sub("__forceinline__", "inline", output_string) + return output_string + +def replace_math_functions(input_string): + """ FIXME: Temporarily replace std:: invocations of math functions with non-std:: versions to prevent linker errors + NOTE: This can lead to correctness issues when running tests, since the correct version of the math function (exp/expf) might not get called. + Plan is to remove this function once HIP supports std:: math function calls inside device code + """ + output_string = input_string + output_string = re.sub("std::exp\(", "::exp(", output_string) + output_string = re.sub("std::log\(", "::log(", output_string) + output_string = re.sub("std::pow\(", "::pow(", output_string) + return output_string def disable_function(input_string, function, replace_style): """ Finds and disables a function in a particular file. @@ -497,6 +517,12 @@ def preprocessor(filepath, stats): if not filepath.endswith("THCGeneral.h.in"): output_source = disable_asserts(output_source) + # Replace std:: with non-std:: versions + output_source = replace_math_functions(output_source) + + # Replace __forceinline__ with inline + output_source = replace_forceinline(output_source) + # Overwrite file contents fileobj.seek(0) fileobj.write(output_source)