Skip to content

Remove duplicate math transpilation function #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 5 additions & 32 deletions tools/amd_build/pyHIPIFY/hipify-python.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,9 +507,9 @@ def replace_math_functions(input_string):
Plan is to remove this function once HIP supports std:: math function calls inside device code
"""
output_string = input_string
output_string = re.sub("std::exp\(", "::exp(", output_string)
output_string = re.sub("std::log\(", "::log(", output_string)
output_string = re.sub("std::pow\(", "::pow(", output_string)
for func in MATH_TRANSPILATIONS:
output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func]))

return output_string


Expand Down Expand Up @@ -764,10 +764,8 @@ def preprocessor(filepath, stats, hipify_caffe2):
# output_source = disable_asserts(output_source)

# Replace std:: with non-std:: versions
output_source = replace_math_functions(output_source)

# Replace std:: with non-std:: versions
output_source = transpile_device_math(output_source)
if re.search(r"\.cu$", filepath) or re.search(r"\.cuh$", filepath):
output_source = replace_math_functions(output_source)

# Replace __forceinline__ with inline
output_source = replace_forceinline(output_source)
Expand Down Expand Up @@ -948,31 +946,6 @@ def disable_module(input_file):
f.truncate()


def transpile_device_math(input_string):
""" Temporarily replace std:: invocations of math functions with non-std:: versions."""
# Extract device code positions
get_kernel_definitions = [k for k in re.finditer( r"(template[ ]*<(.*)>\n.*\n?)?(__global__|__device__) void[\n| ](\w+(\(.*\))?)\(", input_string)]

# Prepare output
output_string = input_string

# Iterate through each kernel definition
for kernel in get_kernel_definitions:
# Find the final paranthesis that closes this kernel function definition.
_, paranth_end = find_bracket_group(input_string, kernel.end() - 1)

# Replace all std:: math functions within range [start...ending]
selection = input_string[kernel.start():paranth_end + 1]
selection_transpiled = selection
for func in MATH_TRANSPILATIONS:
selection_transpiled = selection_transpiled.replace(func, MATH_TRANSPILATIONS[func])

# Perform replacements inside the output_string
output_string = output_string.replace(selection, selection_transpiled)

return output_string


def extract_arguments(start, string):
""" Return the list of arguments in the upcoming function parameter closure.
Example:
Expand Down