Skip to content

Commit a333147

Browse files
authored
Merge pull request #15 from iotamudelta/master
next round of fixes to address comments
2 parents dc103d4 + 9eb80c9 commit a333147

File tree

2 files changed

+53
-40
lines changed

2 files changed

+53
-40
lines changed

tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,8 @@
574574
"CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE": ("hipGraphicsMapFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
575575
"CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY": ("hipGraphicsMapFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
576576
"CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD": ("hipGraphicsMapFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
577-
"CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE": ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
578-
"CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY": ("hipGraphicsRegisterFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
577+
"CU_GRAPHICS_REGISTER_FLAGS_NONE": ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
578+
"CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY": ("hipGraphicsRegisterFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
579579
"CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD": ("hipGraphicsRegisterFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
580580
"CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST": ("hipGraphicsRegisterFlagsSurfaceLoadStore", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),
581581
"CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER": ("hipGraphicsRegisterFlagsTextureGather", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED),

tools/amd_build/pyHIPIFY/hipify-python.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import yaml
3434

3535
from functools import reduce
36+
from enum import Enum
3637
from cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
3738

3839

@@ -56,6 +57,27 @@ class bcolors:
5657
UNDERLINE = '\033[4m'
5758

5859

60+
class disablefuncmode(Enum):
61+
""" How to disable functions
62+
0 - Remove the function entirely (includes the signature).
63+
1 - Stub the function and return an empty object based off the type.
64+
2 - Add !defined(__HIP_PLATFORM_HCC__) preprocessors around the function.
65+
This macro is defined by HIP if the compiler used is hcc.
66+
3 - Add !defined(__HIP_DEVICE_COMPILE__) preprocessors around the function.
67+
This macro is defined by HIP if either hcc or nvcc are used in the device path.
68+
4 - Stub the function and throw an exception at runtime.
69+
5 - Stub the function and throw an assert(0).
70+
6 - Stub the function and keep an empty body.
71+
"""
72+
REMOVE = 0
73+
STUB = 1
74+
HCC_MACRO = 2
75+
DEVICE_MACRO = 3
76+
EXCEPTION = 4
77+
ASSERT = 5
78+
EMPTYBODY = 6
79+
80+
5981
def update_progress_bar(total, progress):
6082
"""
6183
Displays and updates a console progress bar.
@@ -72,10 +94,10 @@ def update_progress_bar(total, progress):
7294
status)
7395

7496
# Send the progress to stdout.
75-
sys.stdout.write(text)
97+
sys.stderr.write(text)
7698

7799
# Send the buffered text to stdout!
78-
sys.stdout.flush()
100+
sys.stderr.flush()
79101

80102

81103
def filename_ends_with_extension(filename, extensions):
@@ -159,7 +181,7 @@ def compute_stats(stats):
159181
def processKernelLaunches(string, stats):
160182
""" Replace the CUDA style Kernel launches with the HIP style kernel launches."""
161183
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
162-
string = re.sub(r'([ ]+)(detail+)::[ ]+\\\n[ ]+', lambda inp: "%s%s::" % (inp.group(1), inp.group(2)), string)
184+
string = re.sub(r'([ ]+)(detail?)::[ ]+\\\n[ ]+', lambda inp: "%s%s::" % (inp.group(1), inp.group(2)), string)
163185

164186
def grab_method_and_template(in_kernel):
165187
# The positions for relevant kernel components.
@@ -232,6 +254,8 @@ def find_kernel_bounds(string):
232254

233255
# Get kernel ending position (adjust end point past the >>>)
234256
kernel_end = string.find(">>>", kernel_start) + 3
257+
if kernel_end <= 0:
258+
raise InputError("no kernel end found")
235259

236260
# Add to list of traversed kernels
237261
kernel_positions.append({"start": kernel_start, "end": kernel_end,
@@ -248,11 +272,11 @@ def find_kernel_bounds(string):
248272
# Get kernel components
249273
params = grab_method_and_template(kernel)
250274

251-
# Find paranthesis after kernel launch
252-
paranthesis = string.find("(", kernel["end"])
275+
# Find parenthesis after kernel launch
276+
parenthesis = string.find("(", kernel["end"])
253277

254278
# Extract cuda kernel
255-
cuda_kernel = string[params[0]["start"]:paranthesis + 1]
279+
cuda_kernel = string[params[0]["start"]:parenthesis + 1]
256280

257281
# Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
258282
num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))
@@ -270,24 +294,24 @@ def find_kernel_bounds(string):
270294
return output_string
271295

272296

273-
def find_paranthesis_end(input_string, start):
274-
inside_paranthesis = False
275-
parans = 0
297+
def find_parenthesis_end(input_string, start):
298+
inside_parenthesis = False
299+
parens = 0
276300
pos = start
277301
p_start, p_end = -1, -1
278302

279303
while pos < len(input_string):
280304
if input_string[pos] == "(":
281-
if inside_paranthesis is False:
282-
inside_paranthesis = True
283-
parans = 1
305+
if inside_parenthesis is False:
306+
inside_parenthesis = True
307+
parens = 1
284308
p_start = pos
285309
else:
286-
parans += 1
287-
elif input_string[pos] == ")" and inside_paranthesis:
288-
parans -= 1
310+
parens += 1
311+
elif input_string[pos] == ")" and inside_parenthesis:
312+
parens -= 1
289313

290-
if parans == 0:
314+
if parens == 0:
291315
p_end = pos
292316
return p_start, p_end
293317

@@ -302,7 +326,7 @@ def disable_asserts(input_string):
302326
output_string = input_string
303327
asserts = list(re.finditer(r"\bassert[ ]*\(", input_string))
304328
for assert_item in asserts:
305-
p_start, p_end = find_paranthesis_end(input_string, assert_item.end() - 1)
329+
p_start, p_end = find_parenthesis_end(input_string, assert_item.end() - 1)
306330
start = assert_item.start()
307331
output_string = output_string.replace(input_string[start:p_end + 1], "")
308332
return output_string
@@ -321,13 +345,6 @@ def disable_function(input_string, function, replace_style):
321345
e.g. "overlappingIndices"
322346
323347
replace_style - The style to use when stubbing functions.
324-
0 - Remove the function entirely (includes the signature).
325-
1 - Stub the function and return an empty object based off the type.
326-
2 - Add !defined(__HIP_PLATFORM_HCC__) preprocessors around the function.
327-
3 - Add !defined(__HIP_DEVICE_COMPILE__) preprocessors around the function.
328-
4 - Stub the function and throw an exception at runtime.
329-
5 - Stub the function and throw an assert(0).
330-
6 - Stub the function and keep an empty body.
331348
"""
332349
# void (*)(hcrngStateMtgp32 *, int, float *, double, double)
333350
info = {
@@ -406,11 +423,11 @@ def disable_function(input_string, function, replace_style):
406423
function_body = input_string[info["function_start"]:info["function_end"] + 1]
407424

408425
# Remove the entire function body
409-
if replace_style == 0:
426+
if replace_style == disablefuncmode.REMOVE:
410427
output_string = input_string.replace(function_body, "")
411428

412429
# Stub the function based off its return type.
413-
elif replace_style == 1:
430+
elif replace_style == disablefuncmode.STUB:
414431
# void return type
415432
if func_info["return_type"] == "void" or func_info["return_type"] == "static void":
416433
stub = "%s{\n}" % (function_string)
@@ -423,32 +440,32 @@ def disable_function(input_string, function, replace_style):
423440
output_string = input_string.replace(function_body, stub)
424441

425442
# Add HIP Preprocessors.
426-
elif replace_style == 2:
443+
elif replace_style == disablefuncmode.HCC_MACRO:
427444
output_string = input_string.replace(
428445
function_body,
429446
"#if !defined(__HIP_PLATFORM_HCC__)\n%s\n#endif" % function_body)
430447

431448
# Add HIP Preprocessors.
432-
elif replace_style == 3:
449+
elif replace_style == disablefuncmode.DEVICE_MACRO:
433450
output_string = input_string.replace(
434451
function_body,
435452
"#if !defined(__HIP_DEVICE_COMPILE__)\n%s\n#endif" % function_body)
436453

437454
# Throw an exception at runtime.
438-
elif replace_style == 4:
455+
elif replace_style == disablefuncmode.EXCEPTION:
439456
stub = "%s{\n%s;\n}" % (
440457
function_string,
441458
'throw std::runtime_error("The function %s is not implemented.")' %
442459
function_string.replace("\n", " "))
443460
output_string = input_string.replace(function_body, stub)
444461

445-
elif replace_style == 5:
462+
elif replace_style == disablefuncmode.ASSERT:
446463
stub = "%s{\n%s;\n}" % (
447464
function_string,
448465
'assert(0)')
449466
output_string = input_string.replace(function_body, stub)
450467

451-
elif replace_style == 6:
468+
elif replace_style == disablefuncmode.EMPTY:
452469
stub = "%s{\n;\n}" % (function_string)
453470
output_string = input_string.replace(function_body, stub)
454471
return output_string
@@ -501,8 +518,6 @@ def file_specific_replacement(filepath, search_string, replace_string, strict=Fa
501518
f.seek(0)
502519
f.write(contents)
503520
f.truncate()
504-
f.flush()
505-
os.fsync(f)
506521

507522

508523
def file_add_header(filepath, header):
@@ -514,8 +529,6 @@ def file_add_header(filepath, header):
514529
f.seek(0)
515530
f.write(contents)
516531
f.truncate()
517-
f.flush()
518-
os.fsync(f)
519532

520533

521534
def fix_static_global_kernels(in_txt):
@@ -875,15 +888,15 @@ def main():
875888
txt = f.read()
876889
for func in functions:
877890
# TODO - Find fix assertions in HIP for device code.
878-
txt = disable_function(txt, func, 5)
891+
txt = disable_function(txt, func, disablefuncmode.ASSERT)
879892

880893
for func in non_hip_functions:
881894
# Disable this function on HIP stack
882-
txt = disable_function(txt, func, 2)
895+
txt = disable_function(txt, func, disablefuncmode.HCC_MACRO)
883896

884897
for func in not_on_device_functions:
885898
# Disable this function when compiling on Device
886-
txt = disable_function(txt, func, 3)
899+
txt = disable_function(txt, func, disablefuncmode.DEVICE_MACRO)
887900

888901
f.seek(0)
889902
f.write(txt)

0 commit comments

Comments
 (0)