diff --git a/tools/amd_build/pyHIPIFY/hipify-python.py b/tools/amd_build/pyHIPIFY/hipify-python.py index 288809f43f5282..13cc4c36d79991 100755 --- a/tools/amd_build/pyHIPIFY/hipify-python.py +++ b/tools/amd_build/pyHIPIFY/hipify-python.py @@ -507,9 +507,9 @@ def replace_math_functions(input_string): Plan is to remove this function once HIP supports std:: math function calls inside device code """ output_string = input_string - output_string = re.sub("std::exp\(", "::exp(", output_string) - output_string = re.sub("std::log\(", "::log(", output_string) - output_string = re.sub("std::pow\(", "::pow(", output_string) + for func in MATH_TRANSPILATIONS: + output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func])) + return output_string @@ -764,10 +764,8 @@ def preprocessor(filepath, stats, hipify_caffe2): # output_source = disable_asserts(output_source) # Replace std:: with non-std:: versions - output_source = replace_math_functions(output_source) - - # Replace std:: with non-std:: versions - output_source = transpile_device_math(output_source) + if re.search(r"\.cu$", filepath) or re.search(r"\.cuh$", filepath): + output_source = replace_math_functions(output_source) # Replace __forceinline__ with inline output_source = replace_forceinline(output_source) @@ -948,31 +946,6 @@ def disable_module(input_file): f.truncate() -def transpile_device_math(input_string): - """ Temporarily replace std:: invocations of math functions with non-std:: versions.""" - # Extract device code positions - get_kernel_definitions = [k for k in re.finditer( r"(template[ ]*<(.*)>\n.*\n?)?(__global__|__device__) void[\n| ](\w+(\(.*\))?)\(", input_string)] - - # Prepare output - output_string = input_string - - # Iterate through each kernel definition - for kernel in get_kernel_definitions: - # Find the final paranthesis that closes this kernel function definition. - _, paranth_end = find_bracket_group(input_string, kernel.end() - 1) - - # Replace all std:: math functions within range [start...ending] - selection = input_string[kernel.start():paranth_end + 1] - selection_transpiled = selection - for func in MATH_TRANSPILATIONS: - selection_transpiled = selection_transpiled.replace(func, MATH_TRANSPILATIONS[func]) - - # Perform replacements inside the output_string - output_string = output_string.replace(selection, selection_transpiled) - - return output_string - - def extract_arguments(start, string): """ Return the list of arguments in the upcoming function parameter closure. Example: