Skip to content

Fix static_cast logic in hipify script for better coverage #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 54 additions & 41 deletions tools/amd_build/pyHIPIFY/hipify-python.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def filename_ends_with_extension(filename, extensions):

def inside_included_directories(dirpath, rootpath, include_dirs):
"""Helper method to see if filename within included directories"""
return reduce(lambda result, included_directory: re.match(r'(%s)\b' % os.path.join(rootpath, included_directory), dirpath) or result, include_dirs, None)
return reduce(lambda result, included_directory: re.match(r'{0}\b'.format(os.path.join(rootpath, included_directory)), dirpath) or result, include_dirs, None)


def walk_over_directory(rootpath, extensions, show_detailed=False, include_dirs=None):
Expand Down Expand Up @@ -169,19 +169,19 @@ def compute_stats(stats):
unsupported_calls = set(cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"])

# Print the number of unsupported calls
print("Total number of unsupported CUDA function calls: %d" % (len(unsupported_calls)))
print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))

# Print the list of unsupported calls
print(", ".join(unsupported_calls))

# Print the number of kernel launches
print("\nTotal number of replaced kernel launches: %d" % (len(stats["kernel_launches"])))
print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))


def processKernelLaunches(string, stats):
""" Replace the CUDA style Kernel launches with the HIP style kernel launches."""
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
string = re.sub(r'([ ]+)(detail?)::[ ]+\\\n[ ]+', lambda inp: "%s%s::" % (inp.group(1), inp.group(2)), string)
string = re.sub(r'([ ]+)(detail?)::[ ]+\\\n[ ]+', lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)

def grab_method_and_template(in_kernel):
# The positions for relevant kernel components.
Expand Down Expand Up @@ -368,7 +368,7 @@ def disable_function(input_string, function, replace_style):
}

# Create function string to search for
function_string = "%s%s%s" % (
function_string = "{0}{1}{2}".format(
func_info["return_type"],
func_info["function_name"],
func_info["function_args"]
Expand All @@ -378,8 +378,7 @@ def disable_function(input_string, function, replace_style):
info["function_start"] = input_string.find(function_string)
else:
# Automatically detect signature.
the_match = re.search(r"(((.*) (\*)?)(%s)(\([^{)]*\)))\s*{" %
(function.replace("(", "\(").replace(")", "\)")), input_string)
the_match = re.search(r"(((.*) (\*)?)({0})(\([^{{)]*\)))\s*{{".format(function.replace("(", "\(").replace(")", "\)")), input_string)
if the_match is None:
return input_string

Expand Down Expand Up @@ -430,43 +429,43 @@ def disable_function(input_string, function, replace_style):
elif replace_style == disablefuncmode.STUB:
# void return type
if func_info["return_type"] == "void" or func_info["return_type"] == "static void":
stub = "%s{\n}" % (function_string)
stub = "{0}{{\n}}".format(function_string)
# pointer return type
elif "*" in func_info["return_type"]:
stub = "%s{\nreturn %s;\n}" % (function_string, "NULL") # nullptr
stub = "{0}{{\nreturn {1};\n}}".format(function_string, "NULL") # nullptr
else:
stub = "%s{\n%s stub_var;\nreturn stub_var;\n}" % (function_string, func_info["return_type"])
stub = "{0}{{\n{1} stub_var;\nreturn stub_var;\n}}".format(function_string, func_info["return_type"])

output_string = input_string.replace(function_body, stub)

# Add HIP Preprocessors.
elif replace_style == disablefuncmode.HCC_MACRO:
output_string = input_string.replace(
function_body,
"#if !defined(__HIP_PLATFORM_HCC__)\n%s\n#endif" % function_body)
"#if !defined(__HIP_PLATFORM_HCC__)\n{0}\n#endif".format(function_body))

# Add HIP Preprocessors.
elif replace_style == disablefuncmode.DEVICE_MACRO:
output_string = input_string.replace(
function_body,
"#if !defined(__HIP_DEVICE_COMPILE__)\n%s\n#endif" % function_body)
"#if !defined(__HIP_DEVICE_COMPILE__)\n{0}\n#endif".format(function_body))

# Throw an exception at runtime.
elif replace_style == disablefuncmode.EXCEPTION:
stub = "%s{\n%s;\n}" % (
stub = "{0}{{\n{1};\n}}".format(
function_string,
'throw std::runtime_error("The function %s is not implemented.")' %
function_string.replace("\n", " "))
'throw std::runtime_error("The function {0} is not implemented.")'.format(
function_string.replace("\n", " ")))
output_string = input_string.replace(function_body, stub)

elif replace_style == disablefuncmode.ASSERT:
stub = "%s{\n%s;\n}" % (
stub = "{0}{{\n{1};\n}}".format(
function_string,
'assert(0)')
output_string = input_string.replace(function_body, stub)

elif replace_style == disablefuncmode.EMPTY:
stub = "%s{\n;\n}" % (function_string)
stub = "{0}{{\n;\n}}".format(function_string)
output_string = input_string.replace(function_body, stub)
return output_string

Expand All @@ -489,7 +488,7 @@ def preprocessor(filepath, stats):
stats["unsupported_calls"].append((cuda_type, filepath))

if cuda_type in output_source:
output_source = re.sub(r'\b(%s)\b' % cuda_type, lambda x: hip_type, output_source)
output_source = re.sub(r'\b({0})\b'.format(cuda_type), lambda x: hip_type, output_source)

# Perform Kernel Launch Replacements
output_source = processKernelLaunches(output_source, stats)
Expand All @@ -512,7 +511,7 @@ def file_specific_replacement(filepath, search_string, replace_string, strict=Fa
with openf(filepath, "r+") as f:
contents = f.read()
if strict:
contents = re.sub(r'\b(%s)\b' % search_string, lambda x: replace_string, contents)
contents = re.sub(r'\b({0})\b'.format(search_string), lambda x: replace_string, contents)
else:
contents = contents.replace(search_string, replace_string)
f.seek(0)
Expand All @@ -524,8 +523,8 @@ def file_add_header(filepath, header):
with openf(filepath, "r+") as f:
contents = f.read()
if header[0] != "<" and header[-1] != ">":
header = '"%s"' % header
contents = ('#include %s \n' % header) + contents
header = '"{0}"'.format(header)
contents = ('#include {0} \n'.format(header)) + contents
f.seek(0)
f.write(contents)
f.truncate()
Expand Down Expand Up @@ -595,7 +594,7 @@ def get_kernel_template_params(the_file, KernelDictionary):
break
if len(template_arguments) == 1 and template_arguments[0].strip() in ["Dtype", "T"]:
# Updates kernel
kernel_with_template = "%s<real>" % (kernel_name)
kernel_with_template = "{0}<real>".format(kernel_name)
else:
kernel_with_template = kernel_name
formatted_args = {}
Expand All @@ -613,10 +612,10 @@ def get_kernel_template_params(the_file, KernelDictionary):
kernel_params = kernel.group(2).split(",")[1:]

if kernel_gen_type == 1:
kernel_args = {1: "int", 2: "%s *" % kernel_params[0], 3: kernel_params[1]}
kernel_args = {1: "int", 2: "{0} *".format(kernel_params[0]), 3: kernel_params[1]}

if kernel_gen_type == 2:
kernel_args = {1: "int", 2: "%s *" % kernel_params[0], 3: kernel_params[1], 4: kernel_params[2]}
kernel_args = {1: "int", 2: "{0} *".format(kernel_params[0]), 3: kernel_params[1], 4: kernel_params[2]}

# Argument at position 1 should be int
KernelDictionary[kernel_name] = {"kernel_with_template": kernel_name, "arg_types": kernel_args}
Expand All @@ -628,7 +627,7 @@ def disable_unsupported_function_call(function, input_string, replacement):
output_string = input_string

# Find all calls to the function
calls = re.finditer(r"\b%s\b" % function, input_string)
calls = re.finditer(r"\b{0}\b".format(function), input_string)

# Do replacements
for call in calls:
Expand Down Expand Up @@ -666,15 +665,20 @@ def disable_module(input_file):
last = list(re.finditer(r"#include .*\n", txt))[-1]
end = last.end()

disabled = "%s#if !defined(__HIP_PLATFORM_HCC__)\n%s\n#endif" % (txt[0:end], txt[end:])
disabled = "{0}#if !defined(__HIP_PLATFORM_HCC__)\n{1}\n#endif".format(txt[0:end], txt[end:])

f.seek(0)
f.write(disabled)
f.truncate()


def extract_arguments(start, string):
""" Return the list of arguments in the upcoming function parameter closure"""
""" Return the list of arguments in the upcoming function parameter closure
This function needs a string that contains function arguments fully encapsulated within opening and closing parentheses.
Eg:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output): '[{'start': 1, 'end': 7}, {'start': 8, 'end': 16}, {'start': 17, 'end': 19}, {'start': 20, 'end': 53}]'
"""

arguments = []
closures = {
"<": 0,
Expand Down Expand Up @@ -709,9 +713,18 @@ def extract_arguments(start, string):

return arguments

# Add static_cast to ensure that the type of kernel arguments matches that in the corresponding kernel definition

def add_static_casts(directory, extensions, KernelTemplateParams):
"""Added necessary static casts to kernel launches due to issue in HIP"""
"""Added necessary static casts to kernel launches to match kernel argument type to corresponding kernel definition
Eg.
old_kernel_launch: ' createBatchGemmBuffer, grid, block, 0, THCState_getCurrentStream(state),
(const real**)d_result, THCTensor_(data)(state, ra__),
ra__->stride[0], num_batches'
new_kernel_launch: ' createBatchGemmBuffer, grid, block, 0, THCState_getCurrentStream(state),
(const real**)d_result, THCTensor_(data)(state, ra__),
static_cast<int64_t>(ra__->stride[0]), static_cast<int64_t>(num_batches)'
"""
# Add static_casts<> to all kernel launches.
for (dirpath, _dirnames, filenames) in os.walk(directory):
for filename in filenames:
Expand All @@ -726,7 +739,8 @@ def add_static_casts(directory, extensions, KernelTemplateParams):

# Check if we have templating + static_cast information
argument_strings = [input_source[arg["start"]:arg["end"]] for arg in arguments]
kernel_name = argument_strings[0].strip()
original_kernel_name_with_template = argument_strings[0].strip()
kernel_name = original_kernel_name_with_template.split("<")[0].strip()
ignore = ["upscale"]
if kernel_name in KernelTemplateParams and kernel_name not in ignore:
# Add template to the kernel
Expand All @@ -740,21 +754,20 @@ def add_static_casts(directory, extensions, KernelTemplateParams):
kernel_params = argument_strings[5:]
for arg_idx, arg in enumerate(kernel_params):
if arg_idx in argument_types:
arg = kernel_params[arg_idx].strip()
the_type = argument_types[arg_idx]
the_arg = arg.replace("\n", "").strip()
the_arg = arg.replace("\n", "").replace("\\", "").strip()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is replace("\\", "") added here?

Copy link
Collaborator Author

@jithunnair-amd jithunnair-amd Jul 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are instances where the kernel invocation is part of a preprocessor multiline macro. That's where the "\\" comes from (for the literal backslash used in a multiline macro).

if the_type in ["int", "const int", "int64_t", "THCIndex_t *", "const int *", "ptrdiff_t", "long", "const int64_t*", "int64_t *", "double"]:
static_argument = "static_cast<%s>(%s)" % (the_type, the_arg)
static_argument = arg.replace(the_arg, static_argument)

# Update to static_cast
new_kernel_launch = re.sub(r'\b(%s)\b' %
arg, lambda x: static_argument, new_kernel_launch)
static_argument = "static_cast<{0}>({1})".format(the_type, the_arg)

def replace_arg(match):
return match.group(1) + static_argument + match.group(3)
# Update to static_cast, account for cases where argument is at start/end of string
new_kernel_launch = re.sub(r'(^|\W)({0})(\W|$)'.format(re.escape(the_arg)), replace_arg, new_kernel_launch)

# Add template type
if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")
new_kernel_launch = re.sub(r'\b(%s)\b' % kernel_name,
new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
lambda x: kernel_name_with_template, new_kernel_launch)

# Replace Launch
Expand Down Expand Up @@ -926,7 +939,7 @@ def main():
s_constants = []

if not os.path.exists(filepath):
print("\n" + bcolors.WARNING + "YAML Warning: File %s does not exist." % filepath + bcolors.ENDC)
print("\n" + bcolors.WARNING + "YAML Warning: File {0} does not exist.".format(filepath) + bcolors.ENDC)
continue

with openf(filepath, "r+") as f:
Expand All @@ -938,7 +951,7 @@ def main():

# Disable Constants w/ Boundary.
for const in constants:
txt = re.sub(r"\b%s\b" % const, constants[const], txt)
txt = re.sub(r"\b{0}\b".format(const), constants[const], txt)

# Disable Constants
for s_const in s_constants:
Expand Down