33
33
import yaml
34
34
35
35
from functools import reduce
36
+ from enum import Enum
36
37
from cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
37
38
38
39
@@ -56,6 +57,27 @@ class bcolors:
56
57
UNDERLINE = '\033 [4m'
57
58
58
59
60
+ class disablefuncmode (Enum ):
61
+ """ How to disable functions
62
+ 0 - Remove the function entirely (includes the signature).
63
+ 1 - Stub the function and return an empty object based off the type.
64
+ 2 - Add !defined(__HIP_PLATFORM_HCC__) preprocessors around the function.
65
+ This macro is defined by HIP if the compiler used is hcc.
66
+ 3 - Add !defined(__HIP_DEVICE_COMPILE__) preprocessors around the function.
67
+ This macro is defined by HIP if either hcc or nvcc are used in the device path.
68
+ 4 - Stub the function and throw an exception at runtime.
69
+ 5 - Stub the function and throw an assert(0).
70
+ 6 - Stub the function and keep an empty body.
71
+ """
72
+ REMOVE = 0
73
+ STUB = 1
74
+ HCC_MACRO = 2
75
+ DEVICE_MACRO = 3
76
+ EXCEPTION = 4
77
+ ASSERT = 5
78
+ EMPTYBODY = 6
79
+
80
+
59
81
def update_progress_bar (total , progress ):
60
82
"""
61
83
Displays and updates a console progress bar.
@@ -72,10 +94,10 @@ def update_progress_bar(total, progress):
72
94
status )
73
95
74
96
# Send the progress to stdout.
75
- sys .stdout .write (text )
97
+ sys .stderr .write (text )
76
98
77
99
# Send the buffered text to stdout!
78
- sys .stdout .flush ()
100
+ sys .stderr .flush ()
79
101
80
102
81
103
def filename_ends_with_extension (filename , extensions ):
@@ -159,7 +181,7 @@ def compute_stats(stats):
159
181
def processKernelLaunches (string , stats ):
160
182
""" Replace the CUDA style Kernel launches with the HIP style kernel launches."""
161
183
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
162
- string = re .sub (r'([ ]+)(detail+ )::[ ]+\\\n[ ]+' , lambda inp : "%s%s::" % (inp .group (1 ), inp .group (2 )), string )
184
+ string = re .sub (r'([ ]+)(detail? )::[ ]+\\\n[ ]+' , lambda inp : "%s%s::" % (inp .group (1 ), inp .group (2 )), string )
163
185
164
186
def grab_method_and_template (in_kernel ):
165
187
# The positions for relevant kernel components.
@@ -232,6 +254,8 @@ def find_kernel_bounds(string):
232
254
233
255
# Get kernel ending position (adjust end point past the >>>)
234
256
kernel_end = string .find (">>>" , kernel_start ) + 3
257
+ if kernel_end <= 0 :
258
+ raise InputError ("no kernel end found" )
235
259
236
260
# Add to list of traversed kernels
237
261
kernel_positions .append ({"start" : kernel_start , "end" : kernel_end ,
@@ -248,11 +272,11 @@ def find_kernel_bounds(string):
248
272
# Get kernel components
249
273
params = grab_method_and_template (kernel )
250
274
251
- # Find paranthesis after kernel launch
252
- paranthesis = string .find ("(" , kernel ["end" ])
275
+ # Find parenthesis after kernel launch
276
+ parenthesis = string .find ("(" , kernel ["end" ])
253
277
254
278
# Extract cuda kernel
255
- cuda_kernel = string [params [0 ]["start" ]:paranthesis + 1 ]
279
+ cuda_kernel = string [params [0 ]["start" ]:parenthesis + 1 ]
256
280
257
281
# Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
258
282
num_klp = len (extract_arguments (0 , kernel ["group" ].replace ("<<<" , "(" ).replace (">>>" , ")" )))
@@ -270,24 +294,24 @@ def find_kernel_bounds(string):
270
294
return output_string
271
295
272
296
273
- def find_paranthesis_end (input_string , start ):
274
- inside_paranthesis = False
275
- parans = 0
297
+ def find_parenthesis_end (input_string , start ):
298
+ inside_parenthesis = False
299
+ parens = 0
276
300
pos = start
277
301
p_start , p_end = - 1 , - 1
278
302
279
303
while pos < len (input_string ):
280
304
if input_string [pos ] == "(" :
281
- if inside_paranthesis is False :
282
- inside_paranthesis = True
283
- parans = 1
305
+ if inside_parenthesis is False :
306
+ inside_parenthesis = True
307
+ parens = 1
284
308
p_start = pos
285
309
else :
286
- parans += 1
287
- elif input_string [pos ] == ")" and inside_paranthesis :
288
- parans -= 1
310
+ parens += 1
311
+ elif input_string [pos ] == ")" and inside_parenthesis :
312
+ parens -= 1
289
313
290
- if parans == 0 :
314
+ if parens == 0 :
291
315
p_end = pos
292
316
return p_start , p_end
293
317
@@ -302,7 +326,7 @@ def disable_asserts(input_string):
302
326
output_string = input_string
303
327
asserts = list (re .finditer (r"\bassert[ ]*\(" , input_string ))
304
328
for assert_item in asserts :
305
- p_start , p_end = find_paranthesis_end (input_string , assert_item .end () - 1 )
329
+ p_start , p_end = find_parenthesis_end (input_string , assert_item .end () - 1 )
306
330
start = assert_item .start ()
307
331
output_string = output_string .replace (input_string [start :p_end + 1 ], "" )
308
332
return output_string
@@ -321,13 +345,6 @@ def disable_function(input_string, function, replace_style):
321
345
e.g. "overlappingIndices"
322
346
323
347
replace_style - The style to use when stubbing functions.
324
- 0 - Remove the function entirely (includes the signature).
325
- 1 - Stub the function and return an empty object based off the type.
326
- 2 - Add !defined(__HIP_PLATFORM_HCC__) preprocessors around the function.
327
- 3 - Add !defined(__HIP_DEVICE_COMPILE__) preprocessors around the function.
328
- 4 - Stub the function and throw an exception at runtime.
329
- 5 - Stub the function and throw an assert(0).
330
- 6 - Stub the function and keep an empty body.
331
348
"""
332
349
# void (*)(hcrngStateMtgp32 *, int, float *, double, double)
333
350
info = {
@@ -406,11 +423,11 @@ def disable_function(input_string, function, replace_style):
406
423
function_body = input_string [info ["function_start" ]:info ["function_end" ] + 1 ]
407
424
408
425
# Remove the entire function body
409
- if replace_style == 0 :
426
+ if replace_style == disablefuncmode . REMOVE :
410
427
output_string = input_string .replace (function_body , "" )
411
428
412
429
# Stub the function based off its return type.
413
- elif replace_style == 1 :
430
+ elif replace_style == disablefuncmode . STUB :
414
431
# void return type
415
432
if func_info ["return_type" ] == "void" or func_info ["return_type" ] == "static void" :
416
433
stub = "%s{\n }" % (function_string )
@@ -423,32 +440,32 @@ def disable_function(input_string, function, replace_style):
423
440
output_string = input_string .replace (function_body , stub )
424
441
425
442
# Add HIP Preprocessors.
426
- elif replace_style == 2 :
443
+ elif replace_style == disablefuncmode . HCC_MACRO :
427
444
output_string = input_string .replace (
428
445
function_body ,
429
446
"#if !defined(__HIP_PLATFORM_HCC__)\n %s\n #endif" % function_body )
430
447
431
448
# Add HIP Preprocessors.
432
- elif replace_style == 3 :
449
+ elif replace_style == disablefuncmode . DEVICE_MACRO :
433
450
output_string = input_string .replace (
434
451
function_body ,
435
452
"#if !defined(__HIP_DEVICE_COMPILE__)\n %s\n #endif" % function_body )
436
453
437
454
# Throw an exception at runtime.
438
- elif replace_style == 4 :
455
+ elif replace_style == disablefuncmode . EXCEPTION :
439
456
stub = "%s{\n %s;\n }" % (
440
457
function_string ,
441
458
'throw std::runtime_error("The function %s is not implemented.")' %
442
459
function_string .replace ("\n " , " " ))
443
460
output_string = input_string .replace (function_body , stub )
444
461
445
- elif replace_style == 5 :
462
+ elif replace_style == disablefuncmode . ASSERT :
446
463
stub = "%s{\n %s;\n }" % (
447
464
function_string ,
448
465
'assert(0)' )
449
466
output_string = input_string .replace (function_body , stub )
450
467
451
- elif replace_style == 6 :
468
+ elif replace_style == disablefuncmode . EMPTY :
452
469
stub = "%s{\n ;\n }" % (function_string )
453
470
output_string = input_string .replace (function_body , stub )
454
471
return output_string
@@ -501,8 +518,6 @@ def file_specific_replacement(filepath, search_string, replace_string, strict=Fa
501
518
f .seek (0 )
502
519
f .write (contents )
503
520
f .truncate ()
504
- f .flush ()
505
- os .fsync (f )
506
521
507
522
508
523
def file_add_header (filepath , header ):
@@ -514,8 +529,6 @@ def file_add_header(filepath, header):
514
529
f .seek (0 )
515
530
f .write (contents )
516
531
f .truncate ()
517
- f .flush ()
518
- os .fsync (f )
519
532
520
533
521
534
def fix_static_global_kernels (in_txt ):
@@ -875,15 +888,15 @@ def main():
875
888
txt = f .read ()
876
889
for func in functions :
877
890
# TODO - Find fix assertions in HIP for device code.
878
- txt = disable_function (txt , func , 5 )
891
+ txt = disable_function (txt , func , disablefuncmode . ASSERT )
879
892
880
893
for func in non_hip_functions :
881
894
# Disable this function on HIP stack
882
- txt = disable_function (txt , func , 2 )
895
+ txt = disable_function (txt , func , disablefuncmode . HCC_MACRO )
883
896
884
897
for func in not_on_device_functions :
885
898
# Disable this function when compiling on Device
886
- txt = disable_function (txt , func , 3 )
899
+ txt = disable_function (txt , func , disablefuncmode . DEVICE_MACRO )
887
900
888
901
f .seek (0 )
889
902
f .write (txt )
0 commit comments