diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index a28fb02ed02e23..1329663f8aa89c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -255,14 +255,25 @@ endif() # ---[ Caffe2 HIP sources. if(USE_ROCM) - # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs. - IF(BUILD_ATEN) - HIP_INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDES}) - ENDIF() + if(BUILD_ATEN) + # Get Compile Definitions from the directory (FindHIP.CMake bug) + get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS) + if(MY_DEFINITIONS) + foreach(_item ${MY_DEFINITIONS}) + LIST(APPEND HIP_HCC_FLAGS "-D${_item}") + endforeach() + endif() + + # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs. + hip_include_directories(${Caffe2_HIP_INCLUDES}) + endif() + IF(BUILD_CAFFE2) set_source_files_properties(${Caffe2_HIP_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) ENDIF() - hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS}) + + # Make sure the SHARED flag is set or else -fPIC is dropped when linking with HIP. + hip_add_library(caffe2_hip SHARED ${Caffe2_HIP_SRCS}) # Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added. set_target_properties(caffe2_hip PROPERTIES COMPILE_FLAGS ${HIP_HIPCC_FLAGS}) diff --git a/tools/amd_build/pyHIPIFY/hipify-python.py b/tools/amd_build/pyHIPIFY/hipify-python.py index 15a717c7766cdf..837989a60e1abb 100755 --- a/tools/amd_build/pyHIPIFY/hipify-python.py +++ b/tools/amd_build/pyHIPIFY/hipify-python.py @@ -26,11 +26,13 @@ import argparse import constants +import fnmatch import re import shutil import sys import os import yaml +import ast from functools import reduce from enum import Enum @@ -40,6 +42,7 @@ """This dictionary provides the mapping from PyTorch kernel template types to their actual types.""" PYTORCH_TEMPLATE_MAP = {"Dtype": "real", "T": "real"} +CAFFE2_TEMPLATE_MAP = {} def openf(filename, mode): @@ -210,72 +213,47 @@ def update_progress_bar(total, progress): sys.stderr.flush() -def filename_ends_with_extension(filename, extensions): - """Helper method to see if filename ends with certain extension""" - for ext in extensions: - if filename.endswith("." + ext): - return True +def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), hipify_caffe2=False): + def _fnmatch(filepath, patterns): + return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) - return False + def match_extensions(filename): + """Helper method to see if filename ends with certain extension""" + return os.path.splitext(filename)[1] in extensions + for (dirpath, _, filenames) in os.walk(root_path, topdown=True): + for fn in filenames: + filepath = os.path.join(dirpath, fn) + rel_filepath = os.path.relpath(filepath, root_path) + if _fnmatch(rel_filepath, includes) and (not _fnmatch(rel_filepath, ignores)) and match_extensions(fn): + if hipify_caffe2 and not is_caffe2_gpu_file(filepath): + continue -def inside_included_directories(dirpath, rootpath, include_dirs): - """Helper method to see if filename within included directories""" - for included_directory in include_dirs: - if re.match(r'{0}\b'.format(os.path.join(rootpath, included_directory)), dirpath): - return True + yield filepath - return False - -def walk_over_directory(rootpath, extensions, show_detailed=False, include_dirs=None, show_progress=True): +def preprocess(all_files, show_detailed=False, show_progress=True, hipify_caffe2=False): """ - Recursively walk over the directory and call preprocessor on selected files. + Call preprocessor on selected files. Arguments) - extensions - A plist of file extensions ['cu', 'cuh', ..] - - include_dirs - Directories under the rootpath that should be included in the walk. - show_detailed - Show a detailed summary of the transpilation process. """ - # Default argument for excluded directories. - if include_dirs is None: - include_dirs = [] - # Compute the total number of files to be traversed. - total_files = 0 - for (dirpath, _dirnames, filenames) in os.walk(rootpath): - if inside_included_directories(dirpath, rootpath, include_dirs): - for filename in filenames: - total_files += filename_ends_with_extension(filename, extensions) - - current_file = 0 + total_count = len(all_files) + finished_count = 0 # Preprocessing statistics. stats = {"unsupported_calls": [], "kernel_launches": []} - # Begin traversing the files. - for (dirpath, _dirnames, filenames) in os.walk(rootpath, topdown=True): - # Check if file ends with a valid extensions - if not inside_included_directories(dirpath, rootpath, include_dirs): - continue - - for filename in filenames: - if filename_ends_with_extension(filename, extensions): - # Construct the file's full path - filepath = os.sep.join([dirpath, filename]) - - # Execute the preprocessor on the specified file. - preprocessor(filepath, stats) - - # Update the progress - if show_progress: - print(os.path.join(dirpath, filename)) - update_progress_bar(total_files, current_file) - - current_file += 1 + for filepath in all_files: + preprocessor(filepath, stats, hipify_caffe2) + # Update the progress + if show_progress: + print(filepath) + update_progress_bar(total_count, finished_count) + finished_count += 1 print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC) @@ -297,6 +275,41 @@ def compute_stats(stats): print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"]))) +def add_dim3(kernel_string, cuda_kernel): + '''adds dim3() to the second and third arguments in the kernel launch''' + count = 0 + closure = 0 + kernel_string = kernel_string.replace("<<<", "").replace(">>>", "") + arg_locs = [{} for _ in range(2)] + arg_locs[count]['start'] = 0 + for ind, c in enumerate(kernel_string): + if count > 1: + break + if c == "(": + closure += 1 + elif c == ")": + closure -= 1 + elif (c == "," or ind == len(kernel_string) - 1) and closure == 0: + arg_locs[count]['end'] = ind + count += 1 + if count < 2: + arg_locs[count]['start'] = ind + 1 + + first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1] + second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']] + + first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ") + second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ") + + first_arg_dim3 = "dim3({})".format(first_arg_clean) + second_arg_dim3 = "dim3({})".format(second_arg_clean) + + first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3) + second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3) + cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3) + return cuda_kernel + + def processKernelLaunches(string, stats): """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" # Concat the namespace with the kernel names. (Find cleaner way of doing this later). @@ -396,12 +409,12 @@ def find_kernel_bounds(string): # Extract cuda kernel cuda_kernel = string[params[0]["start"]:parenthesis + 1] - + kernel_string = string[kernel['start']:kernel['end']] + cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel) # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size) num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")"))) - # Transform cuda kernel to hip kernel - hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel[0:-1].replace( + hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace( ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(">>>", ", ") # Replace cuda kernel with hip kernel @@ -450,8 +463,9 @@ def disable_asserts(input_string): output_string = output_string.replace(input_string[start:p_end + 1], "") return output_string + def replace_forceinline(input_string): - """__forceinline__'d methods can cause 'symbol multiply defined' errors in HIP. + """__forceinline__'d methods can cause 'symbol multiply defined' errors in HIP. Adding 'static' to all such methods leads to compilation errors, so replacing '__forceinline__' with 'inline' as a workaround https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_faq.md#what-if-hip-generates-error-of-symbol-multiply-defined-only-on-amd-machine @@ -460,6 +474,7 @@ def replace_forceinline(input_string): output_string = re.sub("__forceinline__", "inline", output_string) return output_string + def replace_math_functions(input_string): """ FIXME: Temporarily replace std:: invocations of math functions with non-std:: versions to prevent linker errors NOTE: This can lead to correctness issues when running tests, since the correct version of the math function (exp/expf) might not get called. @@ -471,6 +486,7 @@ def replace_math_functions(input_string): output_string = re.sub("std::pow\(", "::pow(", output_string) return output_string + def disable_function(input_string, function, replace_style): """ Finds and disables a function in a particular file. @@ -610,11 +626,42 @@ def disable_function(input_string, function, replace_style): return output_string -def preprocessor(filepath, stats): +def get_hip_file_path(filepath, hipify_caffe2): + """ Returns the new name of the hipified file """ + if not hipify_caffe2: + return filepath + + dirpath, filename = os.path.split(filepath) + filename_without_ext, ext = os.path.splitext(filename) + + if 'gpu' in filename_without_ext: + filename_without_ext = filename_without_ext.replace('gpu', 'hip') + else: + filename_without_ext += '_hip' + + if ext == '.cu': + ext = '.cc' + + return os.path.join(dirpath, 'hip', filename_without_ext + ext) + + +def is_caffe2_gpu_file(filepath): + filename = os.path.basename(filepath) + _, ext = os.path.splitext(filename) + return 'gpu' in filename or ext in ['.cu', '.cuh'] + + +def preprocessor(filepath, stats, hipify_caffe2): """ Executes the CUDA -> HIP conversion on the specified file. """ - with openf(filepath, "r+") as fileobj: - output_source = fileobj.read() + fin_path = filepath + with open(fin_path, 'r') as fin: + output_source = fin.read() + fout_path = get_hip_file_path(filepath, hipify_caffe2) + if not os.path.exists(os.path.dirname(fout_path)): + os.makedirs(os.path.dirname(fout_path)) + + with open(fout_path, 'w') as fout: # Perform type, method, constant replacements for mapping in CUDA_TO_HIP_MAPPINGS: for cuda_type, value in mapping.items(): @@ -622,13 +669,22 @@ def preprocessor(filepath, stats): hip_type = value[0] meta_data = value[1:] + if constants.API_CAFFE2 in meta_data and not hipify_caffe2: + continue + if constants.API_RAND in meta_data and hipify_caffe2: + continue + if output_source.find(cuda_type) > -1: # Check if supported if constants.HIP_UNSUPPORTED in meta_data: stats["unsupported_calls"].append((cuda_type, filepath)) if cuda_type in output_source: - output_source = re.sub(r'\b({0})\b'.format(cuda_type), lambda x: hip_type, output_source) + if hipify_caffe2: + pattern = r'({0})'.format(re.escape(cuda_type)) + else: + pattern = r'(\b{0}\b)'.format(re.escape(cuda_type)) + output_source = re.sub(pattern, hip_type, output_source) # Perform Kernel Launch Replacements output_source = processKernelLaunches(output_source, stats) @@ -643,21 +699,14 @@ def preprocessor(filepath, stats): # Replace __forceinline__ with inline output_source = replace_forceinline(output_source) - # Overwrite file contents - fileobj.seek(0) - fileobj.write(output_source) - fileobj.truncate() - fileobj.flush() - - # Flush to disk - os.fsync(fileobj) + fout.write(output_source) def file_specific_replacement(filepath, search_string, replace_string, strict=False): with openf(filepath, "r+") as f: contents = f.read() if strict: - contents = re.sub(r'\b({0})\b'.format(search_string), lambda x: replace_string, contents) + contents = re.sub(r'\b({0})\b'.format(re.escape(search_string)), lambda x: replace_string, contents) else: contents = contents.replace(search_string, replace_string) f.seek(0) @@ -775,7 +824,7 @@ def disable_unsupported_function_call(function, input_string, replacement): output_string = input_string # Find all calls to the function - calls = re.finditer(r"\b{0}\b".format(function), input_string) + calls = re.finditer(r"\b{0}\b".format(re.escape(function)), input_string) # Do replacements for call in calls: @@ -847,7 +896,7 @@ def extract_arguments(start, string): closures["("] -= 1 elif string[current_position] == "<": closures["<"] += 1 - elif string[current_position] == ">" and string[current_position - 1] != "-": + elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0: closures["<"] -= 1 # Finished all arguments @@ -867,7 +916,7 @@ def extract_arguments(start, string): # Add static_cast to ensure that the type of kernel arguments matches that in the corresponding kernel definition -def add_static_casts(directory, extensions, KernelTemplateParams): +def add_static_casts(filepath, KernelTemplateParams): """Add static casts to kernel launches in order to keep launch argument types and kernel definition types matching. Example: @@ -884,73 +933,70 @@ def add_static_casts(directory, extensions, KernelTemplateParams): static_cast_types = ["int", "const int", "int64_t", "THCIndex_t *", "const int *", "ptrdiff_t", "long", "const int64_t*", "int64_t *", "double"] - # Add static_casts<> to all kernel launches. - for (dirpath, _dirnames, filenames) in os.walk(directory): - for filename in filenames: - if filename_ends_with_extension(filename, extensions): - filepath = os.sep.join([dirpath, filename]) - with openf(filepath, "r+") as fileobj: - input_source = fileobj.read() - new_output_source = input_source - for kernel in re.finditer("hipLaunchKernelGGL\(", input_source): - arguments = extract_arguments(kernel.end() - 1, input_source) - - # Check if we have templating + static_cast information - argument_strings = [input_source[arg["start"]:arg["end"]] for arg in arguments] - original_kernel_name_with_template = argument_strings[0].strip() - kernel_name = original_kernel_name_with_template.split("<")[0].strip() - ignore = ["upscale"] - if kernel_name in KernelTemplateParams and kernel_name not in ignore: - # Add template to the kernel - # Add static_casts to relevant arguments - kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"] - argument_types = KernelTemplateParams[kernel_name]["arg_types"] - - # The first 5 arguments are simply (function, number blocks, dimension blocks, shared memory, stream) - # old_kernel_launch_parameters - will contain the actual arguments to the function itself. - old_kernel_launch_parameters = input_source[arguments[5]["start"]:arguments[-1]["end"]] - new_kernel_launch_parameters = old_kernel_launch_parameters - - # full_old_kernel_launch - will contain the entire kernel launch closure. - full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]] - full_new_kernel_launch = full_old_kernel_launch - - kernel_params = argument_strings[5:] - for arg_idx, arg in enumerate(kernel_params): - if arg_idx in argument_types: - the_type = argument_types[arg_idx] - the_arg = arg.replace("\n", "").replace("\\", "").strip() - # Not all types have issues with the hipLaunchKernelGGL. - if the_type in static_cast_types: - static_argument = "static_cast<{0}>({1})".format(the_type, the_arg) - - def replace_arg(match): - return match.group(1) + static_argument + match.group(3) - # Update to static_cast, account for cases where argument is at start/end of string - new_kernel_launch_parameters = re.sub(r'(^|\W)({0})(\W|$)'.format( - re.escape(the_arg)), replace_arg, new_kernel_launch_parameters) - - # replace kernel arguments in full kernel launch arguments w/ static_cast ones - full_new_kernel_launch = full_new_kernel_launch.replace(old_kernel_launch_parameters, new_kernel_launch_parameters) - - # PyTorch Specific: Add template type - # Here the template value will be resolved from to . - if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"): - kernel_name_with_template = kernel_name_with_template.replace("", "") - full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template), - lambda x: kernel_name_with_template, full_new_kernel_launch) - - # Replace Launch - new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch) - - # Overwrite file contents - fileobj.seek(0) - fileobj.write(new_output_source) - fileobj.truncate() - fileobj.flush() - - # Flush to disk - os.fsync(fileobj) + with openf(filepath, "r+") as fileobj: + input_source = fileobj.read() + new_output_source = input_source + for kernel in re.finditer("hipLaunchKernelGGL\(", input_source): + arguments = extract_arguments(kernel.end() - 1, input_source) + + # Check if we have templating + static_cast information + argument_strings = [input_source[arg["start"]:arg["end"]] for arg in arguments] + original_kernel_name_with_template = argument_strings[0].strip() + kernel_name = original_kernel_name_with_template.split("<")[0].strip() + ignore = ["upscale"] + if kernel_name in KernelTemplateParams and kernel_name not in ignore: + # Add template to the kernel + # Add static_casts to relevant arguments + kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"] + argument_types = KernelTemplateParams[kernel_name]["arg_types"] + + # The first 5 arguments are simply (function, number blocks, dimension blocks, shared memory, stream) + # old_kernel_launch_parameters - will contain the actual arguments to the function itself. + old_kernel_launch_parameters = input_source[arguments[5]["start"]:arguments[-1]["end"]] + new_kernel_launch_parameters = old_kernel_launch_parameters + + # full_old_kernel_launch - will contain the entire kernel launch closure. + full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]] + full_new_kernel_launch = full_old_kernel_launch + + kernel_params = argument_strings[5:] + for arg_idx, arg in enumerate(kernel_params): + if arg_idx in argument_types: + the_type = argument_types[arg_idx] + the_arg = arg.replace("\n", "").replace("\\", "").strip() + # Not all types have issues with the hipLaunchKernelGGL. + if the_type in static_cast_types: + static_argument = "static_cast<{0}>({1})".format(the_type, the_arg) + + def replace_arg(match): + return match.group(1) + static_argument + match.group(3) + # Update to static_cast, account for cases where argument is at start/end of string + new_kernel_launch_parameters = re.sub(r'(^|\W)({0})(\W|$)'.format( + re.escape(the_arg)), replace_arg, new_kernel_launch_parameters) + + # replace kernel arguments in full kernel launch arguments w/ static_cast ones + full_new_kernel_launch = full_new_kernel_launch.replace( + old_kernel_launch_parameters, new_kernel_launch_parameters) + + # PyTorch Specific: Add template type + # Here the template value will be resolved from to . + if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"): + kernel_name_with_template = kernel_name_with_template.replace("", "") + + full_new_kernel_launch = re.sub(r'\b{0}\b'.format(re.escape(original_kernel_name_with_template)), + lambda x: kernel_name_with_template, full_new_kernel_launch) + + # Replace Launch + new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch) + + # Overwrite file contents + fileobj.seek(0) + fileobj.write(new_output_source) + fileobj.truncate() + fileobj.flush() + + # Flush to disk + os.fsync(fileobj) def str2bool(v): @@ -990,7 +1036,7 @@ def main(): parser.add_argument( '--extensions', nargs='+', - default=["cu", "cuh", "c", "cpp", "h", "in", "hpp"], + default=[".cu", ".cuh", ".c", ".cpp", ".h", ".in", ".hpp"], help="The extensions for files to run the Hipify script over.", required=False) @@ -1002,10 +1048,10 @@ def main(): required=False) parser.add_argument( - '--include-dirs', + '--includes', nargs='+', default=[], - help="The directories under the root that should be included.", + help="The patterns of files that should be included.", required=False) parser.add_argument( @@ -1022,6 +1068,19 @@ def main(): help="Whether to automatically add static_casts to kernel arguments.", required=False) + parser.add_argument( + '--hipify_caffe2', + type=str2bool, + default=False, + help="Whether to hipify caffe2 source", + required=False) + + parser.add_argument( + '--ignores', + nargs='+', + default=[], + help="list of patterns to ignore for hipifying") + parser.add_argument( '--show-progress', type=str2bool, @@ -1037,33 +1096,14 @@ def main(): sys.exit(1) # If no output directory, provide a default one. - if args.output_directory is "": + if not args.output_directory: args.project_directory.rstrip("/") args.output_directory = args.project_directory + "_amd" - # Make sure output directory does not exist. - if not os.path.exists(args.output_directory): - print("The output folder already exists.") - sys.exit(2) - # Copy from project directory to output directory if not done already. if not os.path.exists(args.output_directory): shutil.copytree(args.project_directory, args.output_directory) - # Extract all of the kernel parameter and template type information. - if args.add_static_casts: - KernelTemplateParams = {} - for (dirpath, _dirnames, filenames) in os.walk(args.output_directory): - for filename in filenames: - if filename_ends_with_extension(filename, args.extensions) and inside_included_directories(dirpath, args.output_directory, args.include_dirs): - the_file = os.sep.join([dirpath, filename]) - - # Store param information inside KernelTemplateParams - get_kernel_template_params( - the_file, - KernelTemplateParams, - PYTORCH_TEMPLATE_MAP) - # Open YAML file with disable information. if args.yaml_settings != "": with openf(args.yaml_settings, "r") as f: @@ -1141,7 +1181,7 @@ def main(): # Disable Constants w\ Boundary. for const in constants: - txt = re.sub(r"\b{0}\b".format(const), constants[const], txt) + txt = re.sub(r"\b{0}\b".format(re.escape(const)), constants[const], txt) # Disable Constants for s_const in s_constants: @@ -1152,17 +1192,28 @@ def main(): f.write(txt) f.truncate() + all_files = list(matched_files_iter(args.output_directory, includes=args.includes, + ignores=args.ignores, extensions=args.extensions, hipify_caffe2=args.hipify_caffe2)) + # Start Preprocessor - walk_over_directory( - args.output_directory, - extensions=args.extensions, + preprocess( + all_files, show_detailed=args.show_detailed, - include_dirs=args.include_dirs, - show_progress=args.show_progress) + show_progress=args.show_progress, + hipify_caffe2=args.hipify_caffe2) + # Extract all of the kernel parameter and template type information. if args.add_static_casts: + KernelTemplateParams = {} + for filepath in all_files: + get_kernel_template_params( + filepath, + KernelTemplateParams, + CAFFE2_TEMPLATE_MAP if args.hipify_caffe2 else PYTORCH_TEMPLATE_MAP) + # Execute the Clang Tool to Automatically add static casts - add_static_casts(args.output_directory, args.extensions, KernelTemplateParams) + for filepath in all_files: + add_static_casts(get_hip_file_path(filepath, hipify_caffe2=args.hipify_caffe2), KernelTemplateParams) if __name__ == '__main__':