Skip to content

Commit a4bd1db

Browse files
committed
Inline generated functions and update validation
Signed-off-by: Jake Tronge <[email protected]>
1 parent 8df3db5 commit a4bd1db

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

ompi/mpi/c/abi.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ class ConvertOMPIToStandard:
259259
COMM = 'ompi_convert_comm_ompi_to_standard'
260260

261261

262+
# Inline function attributes
263+
INLINE_ATTRS = '__opal_attribute_always_inline__ static inline'
264+
265+
262266
def mpi_fn_name_from_base_fn_name(name):
263267
"""Convert from a base name to the standard 'MPI_*' name."""
264268
return f'MPI_{name.capitalize()}'
@@ -318,7 +322,7 @@ def dump_lines(self, lines):
318322
self.dump(line)
319323

320324
def generate_error_convert_fn(self):
321-
self.dump(f'static inline int {ConvertFuncs.ERROR_CLASS}(int error_class)')
325+
self.dump(f'{INLINE_ATTRS} int {ConvertFuncs.ERROR_CLASS}(int error_class)')
322326
self.dump('{')
323327
lines = []
324328
lines.append('switch (error_class) {')
@@ -333,7 +337,7 @@ def generate_error_convert_fn(self):
333337

334338
def generic_convert(self, fn_name, param_name, type_, value_names):
335339
intern_type = self.mangle_name(type_)
336-
self.dump(f'static inline {type_} {fn_name}({intern_type} {param_name})')
340+
self.dump(f'{INLINE_ATTRS} {type_} {fn_name}({intern_type} {param_name})')
337341
self.dump('{')
338342
lines = []
339343
for i, value_name in enumerate(value_names):
@@ -350,7 +354,7 @@ def generic_convert(self, fn_name, param_name, type_, value_names):
350354

351355
def generic_convert_reverse(self, fn_name, param_name, type_, value_names):
352356
intern_type = self.mangle_name(type_)
353-
self.dump(f'static inline {intern_type} {fn_name}({type_} {param_name})')
357+
self.dump(f'{INLINE_ATTRS} {intern_type} {fn_name}({type_} {param_name})')
354358
self.dump('{')
355359
lines = []
356360
for i, value_name in enumerate(value_names):
@@ -388,7 +392,7 @@ def generate_win_convert_fn(self):
388392

389393
def generate_pointer_convert_fn(self, type_, fn_name, constants):
390394
abi_type = self.mangle_name(type_)
391-
self.dump(f'static inline void {fn_name}({abi_type} *ptr)')
395+
self.dump(f'{INLINE_ATTRS} void {fn_name}({abi_type} *ptr)')
392396
self.dump('{')
393397
lines = []
394398
for i, ompi_name in enumerate(constants):
@@ -411,7 +415,7 @@ def generate_file_convert_fn(self):
411415
def generate_status_convert_fn(self):
412416
type_ = 'MPI_Status'
413417
abi_type = self.mangle_name(type_)
414-
self.dump(f'static inline void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)')
418+
self.dump(f'{INLINE_ATTRS} void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)')
415419
self.dump('{')
416420
self.dump(' out->MPI_SOURCE = inp->MPI_SOURCE;')
417421
self.dump(' out->MPI_TAG = inp->MPI_TAG;')
@@ -1051,22 +1055,26 @@ class TemplateParseError(Exception):
10511055

10521056
def validate_body(body):
10531057
"""Validate the body of a template."""
1054-
# Just do a simple bracket balance test determine the bounds of the
1058+
# Just do a simple bracket balance test to determine the bounds of the
10551059
# function body. All lines after the function body should be blank. There
10561060
# are cases where this will break, such as if someone puts code all on one
10571061
# line.
10581062
bracket_balance = 0
10591063
line_count = 0
10601064
for line in body:
10611065
line = line.strip()
1066+
print(bracket_balance, line_count, line[:5])
10621067
if bracket_balance == 0 and line_count > 0 and line:
1063-
raise TemplateParserError('Extra code found in template; only one function body is allowed')
1068+
raise TemplateParseError('Extra code found in template; only one function body is allowed')
10641069

10651070
update = line.count('{') - line.count('}')
10661071
bracket_balance += update
10671072
if bracket_balance != 0:
10681073
line_count += 1
10691074

1075+
if bracket_balance != 0:
1076+
raise TemplateParseError('Mismatched brackets found in template')
1077+
10701078

10711079
class SourceTemplate:
10721080
"""Source template for a single API function."""
@@ -1182,7 +1190,7 @@ def standard_abi(base_name, template):
11821190
internal_name = f'ompi_abi_{template.prototype.name}'
11831191
internal_sig = template.prototype.signature('ompi', internal_name,
11841192
count_type='MPI_Count')
1185-
print('static inline', internal_sig)
1193+
print(INLINE_ATTRS, internal_sig)
11861194
template.print_body(func_name=base_name)
11871195

11881196
def generate_function(prototype, fn_name, internal_fn, count_type='int'):

0 commit comments

Comments
 (0)