@@ -506,9 +506,9 @@ def replace_math_functions(input_string):
506
506
Plan is to remove this function once HIP supports std:: math function calls inside device code
507
507
"""
508
508
output_string = input_string
509
- output_string = re . sub ( "std::exp\(" , "::exp(" , output_string )
510
- output_string = re . sub ( "std::log\(" , "::log(" , output_string )
511
- output_string = re . sub ( "std::pow\(" , "::pow(" , output_string )
509
+ for func in MATH_TRANSPILATIONS :
510
+ output_string = output_string . replace ( r'{}(' . format ( func ), '{}(' . format ( MATH_TRANSPILATIONS [ func ]) )
511
+
512
512
return output_string
513
513
514
514
@@ -763,10 +763,8 @@ def preprocessor(filepath, stats, hipify_caffe2):
763
763
# output_source = disable_asserts(output_source)
764
764
765
765
# Replace std:: with non-std:: versions
766
- output_source = replace_math_functions (output_source )
767
-
768
- # Replace std:: with non-std:: versions
769
- output_source = transpile_device_math (output_source )
766
+ if re .search (r"\.cu$" , filepath ) or re .search (r"\.cuh$" , filepath ):
767
+ output_source = replace_math_functions (output_source )
770
768
771
769
# Replace __forceinline__ with inline
772
770
output_source = replace_forceinline (output_source )
@@ -947,31 +945,6 @@ def disable_module(input_file):
947
945
f .truncate ()
948
946
949
947
950
- def transpile_device_math (input_string ):
951
- """ Temporarily replace std:: invocations of math functions with non-std:: versions."""
952
- # Extract device code positions
953
- get_kernel_definitions = [k for k in re .finditer ( r"(template[ ]*<(.*)>\n.*\n?)?(__global__|__device__) void[\n| ](\w+(\(.*\))?)\(" , input_string )]
954
-
955
- # Prepare output
956
- output_string = input_string
957
-
958
- # Iterate through each kernel definition
959
- for kernel in get_kernel_definitions :
960
- # Find the final paranthesis that closes this kernel function definition.
961
- _ , paranth_end = find_bracket_group (input_string , kernel .end () - 1 )
962
-
963
- # Replace all std:: math functions within range [start...ending]
964
- selection = input_string [kernel .start ():paranth_end + 1 ]
965
- selection_transpiled = selection
966
- for func in MATH_TRANSPILATIONS :
967
- selection_transpiled = selection_transpiled .replace (func , MATH_TRANSPILATIONS [func ])
968
-
969
- # Perform replacements inside the output_string
970
- output_string = output_string .replace (selection , selection_transpiled )
971
-
972
- return output_string
973
-
974
-
975
948
def extract_arguments (start , string ):
976
949
""" Return the list of arguments in the upcoming function parameter closure.
977
950
Example:
0 commit comments