Skip to content

Commit 49b5097

Browse files
jithunnair-amdiotamudelta
authored andcommitted
Remove duplicate math transpilation function (ROCm#233)
* Remove duplicate math transpilation function * Modify regex to expand matches to more __device__ functions * Try a different tack. Apply math transpilations only to .cu and .cuh files * Undo change that's not required anymore since we're not using regex to detect device functions
1 parent dd2c487 commit 49b5097

File tree

1 file changed

+5
-32
lines changed

1 file changed

+5
-32
lines changed

tools/amd_build/pyHIPIFY/hipify-python.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,9 @@ def replace_math_functions(input_string):
506506
Plan is to remove this function once HIP supports std:: math function calls inside device code
507507
"""
508508
output_string = input_string
509-
output_string = re.sub("std::exp\(", "::exp(", output_string)
510-
output_string = re.sub("std::log\(", "::log(", output_string)
511-
output_string = re.sub("std::pow\(", "::pow(", output_string)
509+
for func in MATH_TRANSPILATIONS:
510+
output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func]))
511+
512512
return output_string
513513

514514

@@ -763,10 +763,8 @@ def preprocessor(filepath, stats, hipify_caffe2):
763763
# output_source = disable_asserts(output_source)
764764

765765
# Replace std:: with non-std:: versions
766-
output_source = replace_math_functions(output_source)
767-
768-
# Replace std:: with non-std:: versions
769-
output_source = transpile_device_math(output_source)
766+
if re.search(r"\.cu$", filepath) or re.search(r"\.cuh$", filepath):
767+
output_source = replace_math_functions(output_source)
770768

771769
# Replace __forceinline__ with inline
772770
output_source = replace_forceinline(output_source)
@@ -947,31 +945,6 @@ def disable_module(input_file):
947945
f.truncate()
948946

949947

950-
def transpile_device_math(input_string):
951-
""" Temporarily replace std:: invocations of math functions with non-std:: versions."""
952-
# Extract device code positions
953-
get_kernel_definitions = [k for k in re.finditer( r"(template[ ]*<(.*)>\n.*\n?)?(__global__|__device__) void[\n| ](\w+(\(.*\))?)\(", input_string)]
954-
955-
# Prepare output
956-
output_string = input_string
957-
958-
# Iterate through each kernel definition
959-
for kernel in get_kernel_definitions:
960-
# Find the final paranthesis that closes this kernel function definition.
961-
_, paranth_end = find_bracket_group(input_string, kernel.end() - 1)
962-
963-
# Replace all std:: math functions within range [start...ending]
964-
selection = input_string[kernel.start():paranth_end + 1]
965-
selection_transpiled = selection
966-
for func in MATH_TRANSPILATIONS:
967-
selection_transpiled = selection_transpiled.replace(func, MATH_TRANSPILATIONS[func])
968-
969-
# Perform replacements inside the output_string
970-
output_string = output_string.replace(selection, selection_transpiled)
971-
972-
return output_string
973-
974-
975948
def extract_arguments(start, string):
976949
""" Return the list of arguments in the upcoming function parameter closure.
977950
Example:

0 commit comments

Comments
 (0)