Skip to content

Fix Dynamic Shared Memory. #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tools/amd_build/pyHIPIFY/hipify-python.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,22 @@ def replace_math_functions(input_string):
return output_string


def replace_extern_shared(input_string):
"""Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Example:
"extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""
output_string = input_string
output_string = re.sub(
r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;",
lambda inp: "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
inp.group(1) or "", inp.group(2), inp.group(3)), output_string)

return output_string


def disable_function(input_string, function, replace_style):
""" Finds and disables a function in a particular file.

Expand Down Expand Up @@ -699,6 +715,9 @@ def preprocessor(filepath, stats, hipify_caffe2):
# Replace __forceinline__ with inline
output_source = replace_forceinline(output_source)

# Replace the extern __shared__
output_source = replace_extern_shared(output_source)

fout.write(output_source)


Expand Down