@@ -259,6 +259,10 @@ class ConvertOMPIToStandard:
259
259
COMM = 'ompi_convert_comm_ompi_to_standard'
260
260
261
261
262
+ # Inline function attributes
263
+ INLINE_ATTRS = '__opal_attribute_always_inline__ static inline'
264
+
265
+
262
266
def mpi_fn_name_from_base_fn_name (name ):
263
267
"""Convert from a base name to the standard 'MPI_*' name."""
264
268
return f'MPI_{ name .capitalize ()} '
@@ -318,7 +322,7 @@ def dump_lines(self, lines):
318
322
self .dump (line )
319
323
320
324
def generate_error_convert_fn (self ):
321
- self .dump (f'static inline int { ConvertFuncs .ERROR_CLASS } (int error_class)' )
325
+ self .dump (f'{ INLINE_ATTRS } int { ConvertFuncs .ERROR_CLASS } (int error_class)' )
322
326
self .dump ('{' )
323
327
lines = []
324
328
lines .append ('switch (error_class) {' )
@@ -333,7 +337,7 @@ def generate_error_convert_fn(self):
333
337
334
338
def generic_convert (self , fn_name , param_name , type_ , value_names ):
335
339
intern_type = self .mangle_name (type_ )
336
- self .dump (f'static inline { type_ } { fn_name } ({ intern_type } { param_name } )' )
340
+ self .dump (f'{ INLINE_ATTRS } { type_ } { fn_name } ({ intern_type } { param_name } )' )
337
341
self .dump ('{' )
338
342
lines = []
339
343
for i , value_name in enumerate (value_names ):
@@ -350,7 +354,7 @@ def generic_convert(self, fn_name, param_name, type_, value_names):
350
354
351
355
def generic_convert_reverse (self , fn_name , param_name , type_ , value_names ):
352
356
intern_type = self .mangle_name (type_ )
353
- self .dump (f'static inline { intern_type } { fn_name } ({ type_ } { param_name } )' )
357
+ self .dump (f'{ INLINE_ATTRS } { intern_type } { fn_name } ({ type_ } { param_name } )' )
354
358
self .dump ('{' )
355
359
lines = []
356
360
for i , value_name in enumerate (value_names ):
@@ -388,7 +392,7 @@ def generate_win_convert_fn(self):
388
392
389
393
def generate_pointer_convert_fn (self , type_ , fn_name , constants ):
390
394
abi_type = self .mangle_name (type_ )
391
- self .dump (f'static inline void { fn_name } ({ abi_type } *ptr)' )
395
+ self .dump (f'{ INLINE_ATTRS } void { fn_name } ({ abi_type } *ptr)' )
392
396
self .dump ('{' )
393
397
lines = []
394
398
for i , ompi_name in enumerate (constants ):
@@ -411,7 +415,7 @@ def generate_file_convert_fn(self):
411
415
def generate_status_convert_fn (self ):
412
416
type_ = 'MPI_Status'
413
417
abi_type = self .mangle_name (type_ )
414
- self .dump (f'static inline void { ConvertFuncs .STATUS } ({ abi_type } *out, { type_ } *inp)' )
418
+ self .dump (f'{ INLINE_ATTRS } void { ConvertFuncs .STATUS } ({ abi_type } *out, { type_ } *inp)' )
415
419
self .dump ('{' )
416
420
self .dump (' out->MPI_SOURCE = inp->MPI_SOURCE;' )
417
421
self .dump (' out->MPI_TAG = inp->MPI_TAG;' )
@@ -1051,22 +1055,26 @@ class TemplateParseError(Exception):
1051
1055
1052
1056
def validate_body (body ):
1053
1057
"""Validate the body of a template."""
1054
- # Just do a simple bracket balance test determine the bounds of the
1058
+ # Just do a simple bracket balance test to determine the bounds of the
1055
1059
# function body. All lines after the function body should be blank. There
1056
1060
# are cases where this will break, such as if someone puts code all on one
1057
1061
# line.
1058
1062
bracket_balance = 0
1059
1063
line_count = 0
1060
1064
for line in body :
1061
1065
line = line .strip ()
1066
+ print (bracket_balance , line_count , line [:5 ])
1062
1067
if bracket_balance == 0 and line_count > 0 and line :
1063
- raise TemplateParserError ('Extra code found in template; only one function body is allowed' )
1068
+ raise TemplateParseError ('Extra code found in template; only one function body is allowed' )
1064
1069
1065
1070
update = line .count ('{' ) - line .count ('}' )
1066
1071
bracket_balance += update
1067
1072
if bracket_balance != 0 :
1068
1073
line_count += 1
1069
1074
1075
+ if bracket_balance != 0 :
1076
+ raise TemplateParseError ('Mismatched brackets found in template' )
1077
+
1070
1078
1071
1079
class SourceTemplate :
1072
1080
"""Source template for a single API function."""
@@ -1182,7 +1190,7 @@ def standard_abi(base_name, template):
1182
1190
internal_name = f'ompi_abi_{ template .prototype .name } '
1183
1191
internal_sig = template .prototype .signature ('ompi' , internal_name ,
1184
1192
count_type = 'MPI_Count' )
1185
- print ('static inline' , internal_sig )
1193
+ print (INLINE_ATTRS , internal_sig )
1186
1194
template .print_body (func_name = base_name )
1187
1195
1188
1196
def generate_function (prototype , fn_name , internal_fn , count_type = 'int' ):
0 commit comments